Skip to content

Instantly share code, notes, and snippets.

@uvolchyk
Last active February 21, 2025 12:59
Show Gist options
Save uvolchyk/8ce59857f800ca5cc939ef55150823ac to your computer and use it in GitHub Desktop.
import MetalKit
/// A typed, cursor-based writer over raw buffer memory.
///
/// `position` counts the elements of `T` written so far and `length` is the
/// number of whole `T` elements the underlying storage can hold.
struct ConstantBuffer<T> {
    /// Capacity, in elements of `T`.
    let length: Int
    /// Typed pointer to the first element of the underlying storage.
    var data: UnsafeMutablePointer<T>
    /// Index of the next element to be written.
    var position: Int

    /// Distance in bytes between consecutive elements of `T`.
    static var stepSize: Int {
        MemoryLayout<T>.stride
    }

    /// Wraps the contents of a Metal buffer, binding its memory to `T`.
    ///
    /// Trailing bytes that do not fill a whole element are ignored.
    init (_ buffer: MTLBuffer) {
        let elementCount = buffer.length / Self.stepSize
        let typedPointer = buffer.contents().bindMemory(
            to: T.self,
            capacity: elementCount
        )
        self.init(
            buffer: typedPointer,
            elementsCount: elementCount
        )
    }

    /// Wraps `elementsCount` elements starting at `buffer`, with the cursor at 0.
    init(
        buffer: UnsafeMutablePointer<T>,
        elementsCount: Int
    ) {
        data = buffer
        length = elementsCount
        position = .zero
    }

    /// Number of elements that can still be written.
    ///
    /// This is `0` once `position` has reached `length`, and also after
    /// unchecked writes have pushed it past `length`.
    var availableSpace: UInt {
        // `UInt(_:)` traps on negative input, so clamp before converting.
        UInt(max(0, length - position))
    }

    /// Whether `count` more elements fit before the end of the storage.
    func hasSpace(for count: UInt) -> Bool {
        availableSpace >= count
    }
}
extension ConstantBuffer {
    /// Stores `value` at the cursor and advances the cursor by one element.
    ///
    /// Like `append(_:)`, the write is skipped when the buffer is full, so it
    /// can never run past the end of the storage.
    /// - Parameters:
    ///   - value: Element to store. It is taken `inout` for source compatibility
    ///     only and is not modified.
    ///   - instance: Unused; kept for source compatibility with existing callers.
    mutating func write(
        value: inout T,
        instance: Int = 0
    ) {
        guard hasSpace(for: 1) else { return }
        appendRaw(value)
    }

    /// Appends `value` if there is room; otherwise silently drops it.
    mutating func append(_ value: T) {
        guard hasSpace(for: 1) else { return }
        appendRaw(value)
    }

    /// Appends `value` without a release-build bounds check.
    ///
    /// Callers must guarantee space is available. Debug builds assert on overflow.
    mutating func appendRaw(_ value: T) {
        assert(position < length, "ConstantBuffer overflow")
        data[position] = value
        position &+= 1
    }
}
extension ConstantBuffer where T: SIMDScalar {
    /// Appends the lanes of `vector` in x, y, z order.
    ///
    /// Each lane goes through the bounds-checked `append(_:)`, so lanes that do
    /// not fit are dropped individually.
    mutating func append(_ vector: SIMD3<T>) {
        for lane in vector.indices {
            append(vector[lane])
        }
    }

    /// Appends the lanes of `vector` in x, y, z, w order.
    ///
    /// Each lane goes through the bounds-checked `append(_:)`, so lanes that do
    /// not fit are dropped individually.
    mutating func append(_ vector: SIMD4<T>) {
        for lane in vector.indices {
            append(vector[lane])
        }
    }
}
#include <metal_stdlib>
using namespace metal;
namespace Firework {
    /// Per-instance constants bound at buffer(2) of the scene pass.
    /// Field order and types must match `FireworkScene.Uniform` on the Swift side.
    struct Uniform {
        packed_float4 color;
        float4x4 transform;
        float time;
    };

    /// Values passed from the scene vertex stage to the scene fragment stage.
    struct VertexOut {
        float4 position [[position]];
        float4 color;
        float progress;
        float time;
    };

    /// Transforms one trail vertex by its instance's matrix.
    /// Forwards the instance colour and time, plus the per-vertex progress value.
    vertex VertexOut vertexScene(
        uint vid [[ vertex_id ]],
        uint iid [[ instance_id ]],
        constant packed_float4* position [[ buffer(0) ]],
        constant float* progress [[ buffer(1) ]],
        constant Uniform* uniform [[ buffer(2) ]]
    ) {
        VertexOut out;
        out.position = uniform[iid].transform * float4(position[vid]);
        out.color = uniform[iid].color;
        out.progress = progress[vid];
        out.time = uniform[iid].time;
        return out;
    }

    /// Cheap hash of a 2D position into [0, 1); not a statistically good RNG.
    float rand(float2 n) {
        return fract(sin(dot(n, n)) * length(n));
    }

    /// Value noise: `rand` at the four surrounding lattice corners, blended
    /// bilinearly with smoothstep-eased weights.
    float noise(float2 n) {
        const float2 d = float2(0.0, 1.0);
        float2 b = floor(n);
        float2 f = smoothstep(float2(0.0), float2(1.0), fract(n));
        return mix(
            mix(rand(b), rand(b + d.yx), f.x),
            mix(rand(b + d.xy), rand(b + d.yy), f.x),
            f.y
        );
    }

    /// Shades the firework in two phases split at `breakpoint`: the launch
    /// trail, then the burst.
    ///
    /// Fragments whose progress is ahead of the current time are discarded.
    /// Older fragments dissolve through noise, and the leading edge gets a
    /// brightness boost that fades out near the end of each phase.
    fragment float4 fragmentScene(
        VertexOut in [[stage_in]]
    ) {
        const float breakpoint = 0.5;
        const float noiseScale = 0.8;
        const float noiseFalloff = 3.0;
        const float bloomIntensity = 2.0;
        const float bloomFalloff = 20.0;
        const float time = in.time;
        const float dissolveProgress = in.progress;
        const bool isSecondPhase = time > breakpoint;
        const float normalizedTime = isSecondPhase ? (time - breakpoint) / (1.0 - breakpoint) : time / breakpoint;
        const float normalizedProgress = isSecondPhase ? (dissolveProgress - breakpoint) / (1.0 - breakpoint) : dissolveProgress / breakpoint;
        if (normalizedProgress > normalizedTime) discard_fragment();
        float _noise = 1.0 - noise(in.position.xy * noiseScale);
        // `isSecondPhase` is already a bool; test it directly instead of comparing to a float.
        float delayPeriod = isSecondPhase ? 0.0 : 0.25;
        float delayFactor = (pow(normalizedTime, noiseFalloff * (isSecondPhase ? 1.0 : 1.3)) - delayPeriod) / (1.0 - delayPeriod);
        // smoothstep is undefined when edge0 >= edge1, which happens when
        // normalizedTime == 0; keep the upper edge strictly positive.
        float noiseFade = smoothstep(0.0, max(normalizedTime, 1e-5f), normalizedProgress - delayFactor);
        if (_noise > noiseFade) discard_fragment();
        float colorFactor = clamp(1.0 - (normalizedTime - normalizedProgress), 0.0, 1.0);
        float bloomEffect = bloomIntensity * pow(colorFactor, bloomFalloff);
        float bloomFade = 1.0 - smoothstep(0.9, 1.0, normalizedTime);
        bloomEffect *= bloomFade;
        float3 finalColor = in.color.rgb + in.color.rgb * bloomEffect;
        return float4(finalColor, 1.0);
    }

    /// Output of the full-screen composition vertex stage.
    struct CompositionOut {
        float4 position [[position]];
        float2 texCoord;
    };

    /// Unpacks one quad vertex: `.xy` is the position and `.zw` the texture coordinate.
    vertex CompositionOut vertexComposition(
        constant float4* vertexData [[ buffer(0) ]],
        uint vertexID [[vertex_id]]
    ) {
        CompositionOut out;
        float2 position = vertexData[vertexID].xy;
        float2 texCoord = vertexData[vertexID].zw;
        out.position = float4(position, 0.0, 1.0);
        out.texCoord = texCoord;
        return out;
    }

    /// Adds the blurred bloom texture (scaled by `bloomIntensity`) on top of the scene texture.
    fragment float4 fragmentComposition(
        CompositionOut in [[stage_in]],
        texture2d<float> sceneTexture [[texture(0)]],
        texture2d<float> bloomTexture [[texture(1)]]
    ) {
        constexpr sampler textureSampler (mag_filter::linear, min_filter::linear);
        float4 sceneColor = sceneTexture.sample(textureSampler, in.texCoord);
        float4 bloomColor = bloomTexture.sample(textureSampler, in.texCoord);
        float bloomIntensity = 3.0;
        return sceneColor + bloomColor * bloomIntensity;
    }
}
import simd
/// Geometry description of one firework: a Bézier launch trail followed by a
/// palm-shaped burst that starts where the trail ends.
///
/// Stored-property order matters: it defines the memberwise initializer's
/// argument order.
struct FireworkScene {
    /// Per-instance shader constants.
    ///
    /// The layout must match `Firework::Uniform` in the Metal source
    /// (color, transform, time), so do not reorder these fields.
    struct Uniform {
        let color: simd_float4
        let transform: simd_float4x4
        let time: Float
    }

    // Launch trail: quadratic Bézier from start, bent by control, to end.
    var trailStartPoint: simd_float2
    var trailEndPoint: simd_float2
    var trailControlPoint: simd_float2

    /// Horizontal scale applied to every emitted x coordinate.
    var ratio: Float

    /// Number of quads each ribbon is split into.
    var segments = 8
    /// Full width of every ribbon.
    var trailWidth: Float = 0.05
    /// Number of trails fanned around the burst point.
    var burstPalmTrails = 16
    /// Initial speed of each burst trail.
    var burstPalmRadius: Float = 0.7
}
extension FireworkScene {
    /// Writes the complete firework into the given buffers.
    ///
    /// The launch trail is written first, then the palm burst. Every vertex
    /// adds its position to `vertexBuffer` and one progress value to
    /// `progressBuffer`.
    func draw(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        drawLaunchTrail(vertexBuffer: &vertexBuffer, progressBuffer: &progressBuffer)
        drawBurstPalm(vertexBuffer: &vertexBuffer, progressBuffer: &progressBuffer)
    }
}
private extension FireworkScene {
    /// Emits the launch trail.
    ///
    /// The trail is a ribbon of width `trailWidth` following the quadratic
    /// Bézier curve from `trailStartPoint` through `trailControlPoint` to
    /// `trailEndPoint`. Its progress values run from 0 to 0.5.
    func drawLaunchTrail(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        appendRibbon(
            progressOffset: 0,
            vertexBuffer: &vertexBuffer,
            progressBuffer: &progressBuffer
        ) { t in
            let point = quadraticBezier(
                trailStartPoint,
                trailControlPoint,
                trailEndPoint,
                t
            )
            let tangent = quadraticBezierTangent(
                trailStartPoint,
                trailControlPoint,
                trailEndPoint,
                t
            )
            return (point: point, tangent: tangent)
        }
    }

    /// Emits the palm burst.
    ///
    /// `burstPalmTrails` ribbons fan out evenly around `trailEndPoint`. Each
    /// follows a ballistic arc with initial speed `burstPalmRadius` under a
    /// constant downward acceleration. Progress runs from 0.5 to 1.
    private func drawBurstPalm(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        // `0..<n` traps for negative n; zero trails draws nothing anyway.
        guard burstPalmTrails > 0 else { return }
        let gravity: Float = 0.8
        let trailStart = trailEndPoint
        for i in 0..<burstPalmTrails {
            let angle = (Float.pi * 2 * Float(i)) / Float(burstPalmTrails)
            let initialVelocity = SIMD2<Float>(
                cos(angle),
                sin(angle)
            ) * burstPalmRadius
            appendRibbon(
                progressOffset: 0.5,
                vertexBuffer: &vertexBuffer,
                progressBuffer: &progressBuffer
            ) { t in
                let point = SIMD2<Float>(
                    trailStart.x + initialVelocity.x * t,
                    trailStart.y + initialVelocity.y * t - (gravity * t * t) / 2
                )
                let tangent = SIMD2<Float>(
                    initialVelocity.x,
                    initialVelocity.y - gravity * t
                )
                return (point: point, tangent: tangent)
            }
        }
    }

    /// Appends a triangulated ribbon of width `trailWidth` along a parametric curve.
    ///
    /// The curve is sampled at `segments + 1` evenly spaced parameters `t` in
    /// `0...1`. Each of the `segments` pieces becomes two triangles (six
    /// vertices). Every vertex also gets one progress value,
    /// `0.5 * step / segments + progressOffset`, so progress spans
    /// `progressOffset ... progressOffset + 0.5` along the ribbon.
    ///
    /// - Parameters:
    ///   - progressOffset: Added to every progress value written.
    ///   - vertexBuffer: Receives each vertex via `appendVertex(_:to:)`.
    ///   - progressBuffer: Receives one progress value per vertex.
    ///   - sample: Returns the curve point and its tangent at parameter `t`.
    private func appendRibbon(
        progressOffset: Float,
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>,
        sample: (Float) -> (point: SIMD2<Float>, tangent: SIMD2<Float>)
    ) {
        // `0...segments` traps for negative values, and zero segments yields no quads.
        guard segments > 0 else { return }
        let tIncrement = 1.0 / Float(segments)
        let offset = -trailWidth / 2
        var previousEdges: (SIMD2<Float>, SIMD2<Float>)?
        var previousNormal: SIMD2<Float>?
        for step in 0...segments {
            let t = Float(step) * tIncrement
            let sampled = sample(t)
            // A (near-)zero tangent cannot be normalized. Reuse the previous
            // normal instead of emitting NaN or noise-driven vertices. With no
            // previous normal, the ribbon has zero width at this sample.
            let normal = unitNormal(of: sampled.tangent) ?? previousNormal ?? .zero
            previousNormal = normal
            let p1 = sampled.point + normal * offset
            let p2 = sampled.point - normal * offset
            if let previous = previousEdges {
                let pCurrent = 0.5 * Float(step) / Float(segments) + progressOffset
                let pPrevious = 0.5 * Float(step - 1) / Float(segments) + progressOffset
                // First triangle: previous left, current left, previous right.
                appendVertex(previous.0, to: &vertexBuffer)
                appendVertex(p1, to: &vertexBuffer)
                appendVertex(previous.1, to: &vertexBuffer)
                progressBuffer.appendRaw(pPrevious)
                progressBuffer.appendRaw(pCurrent)
                progressBuffer.appendRaw(pPrevious)
                // Second triangle: previous right, current left, current right.
                appendVertex(previous.1, to: &vertexBuffer)
                appendVertex(p1, to: &vertexBuffer)
                appendVertex(p2, to: &vertexBuffer)
                progressBuffer.appendRaw(pPrevious)
                progressBuffer.appendRaw(pCurrent)
                progressBuffer.appendRaw(pCurrent)
            }
            previousEdges = (p1, p2)
        }
    }

    /// Left-hand unit normal `(-y, x)` of `direction`.
    ///
    /// Returns `nil` when `direction` is too short to normalize reliably.
    private func unitNormal(of direction: SIMD2<Float>) -> SIMD2<Float>? {
        guard simd_length_squared(direction) > 1e-12 else { return nil }
        let unit = normalize(direction)
        return SIMD2<Float>(-unit.y, unit.x)
    }
}
private extension FireworkScene {
    /// Point on the quadratic Bézier curve with the given control points at parameter `t`.
    func quadraticBezier(
        _ start: SIMD2<Float>,
        _ control: SIMD2<Float>,
        _ end: SIMD2<Float>,
        _ t: Float
    ) -> SIMD2<Float> {
        let s = 1 - t
        let startWeight = s * s
        let controlWeight = 2 * s * t
        let endWeight = t * t
        return startWeight * start + controlWeight * control + endWeight * end
    }

    /// First derivative of `quadraticBezier` with respect to `t`; not normalized.
    func quadraticBezierTangent(
        _ start: SIMD2<Float>,
        _ control: SIMD2<Float>,
        _ end: SIMD2<Float>,
        _ t: Float
    ) -> SIMD2<Float> {
        let s = 1 - t
        let firstLeg = control - start
        let secondLeg = end - control
        return (2 * s) * firstLeg + (2 * t) * secondLeg
    }

    /// Appends `point` as four floats `(x * ratio, y, 0, 1)` to `target`.
    ///
    /// Uses `appendRaw`, so there is no release-build bounds check.
    func appendVertex(
        _ point: SIMD2<Float>,
        to target: inout ConstantBuffer<Float>
    ) {
        let homogeneous = SIMD4<Float>(point.x * ratio, point.y, 0.0, 1.0)
        for lane in homogeneous.indices {
            target.appendRaw(homogeneous[lane])
        }
    }
}
import MetalKit
import MetalPerformanceShaders
import SwiftUI
// Platform-neutral name for the view-controller base class: UIKit's
// UIViewController on every non-macOS platform, AppKit's NSViewController on macOS.
#if !os(macOS)
typealias ViewController = UIViewController
#else
typealias ViewController = NSViewController
#endif
/// Hosts an `MTKView` and renders the looping firework animation into it.
final class FireworkViewController: ViewController {
    // MARK: Animation timing

    /// Length of one animation loop; elapsed time wraps at this value.
    let duration: Float = 3.0
    /// `CACurrentMediaTime()` captured when the controller is created.
    private let initialTime = CACurrentMediaTime()

    // MARK: View

    private lazy var canvasView = MTKView()

    // MARK: Metal objects

    private var device: (any MTLDevice)!
    private var commandQueue: (any MTLCommandQueue)!
    private var scenePipelineState: (any MTLRenderPipelineState)!
    private var compositePipelineState: (any MTLRenderPipelineState)!

    // MARK: Buffers

    private var vertexBuffer: (any MTLBuffer)!
    private var progressBuffer: (any MTLBuffer)!
    private var uniformBuffer: (any MTLBuffer)!
    /// Full-screen quad consumed by the composition pass.
    private var sceneBuffer: (any MTLBuffer)!

    // MARK: Render targets

    private var sceneTexture: (any MTLTexture)!
    private var glowTexture: (any MTLTexture)!
}
extension FireworkViewController {
    /// Uses the Metal view as the controller's root view.
    override func loadView() {
        view = canvasView
    }

    /// Creates the Metal device and command queue, attaches the view, and
    /// builds the pipelines and buffers.
    override func viewDidLoad() {
        super.viewDidLoad()
        // Fail loudly here instead of crashing later on an implicitly unwrapped nil.
        guard let systemDevice = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        device = systemDevice
        commandQueue = systemDevice.makeCommandQueue()
        canvasView.device = systemDevice
        canvasView.delegate = self
        buildPipelineStates()
        buildBuffers()
    }

    #if os(macOS)
    override func viewDidLayout() {
        super.viewDidLayout()
        // Size render targets in drawable pixels, matching the size passed to
        // `mtkView(_:drawableSizeWillChange:)`. `bounds.size` is in points,
        // which is half resolution on Retina displays.
        buildResources(size: canvasView.drawableSize)
    }
    #else
    override func viewDidLayoutSubviews() {
        super.viewDidLayoutSubviews()
        // Size render targets in drawable pixels, matching the size passed to
        // `mtkView(_:drawableSizeWillChange:)`. `bounds.size` is in points,
        // which is a fraction of the pixel size on high-density screens.
        buildResources(size: canvasView.drawableSize)
    }
    #endif
}
extension FireworkViewController: MTKViewDelegate {
    /// Renders one frame in three passes:
    /// 1. Draws the firework geometry into `sceneTexture`.
    /// 2. Gaussian-blurs it into `glowTexture`.
    /// 3. Composites both into the view's drawable.
    func draw(in view: MTKView) {
        // An empty view has no meaningful aspect ratio (height / width would
        // divide by zero). The render targets only exist after layout.
        // Skip the frame in either case.
        guard
            view.bounds.width > 0,
            view.bounds.height > 0,
            let sceneTexture = self.sceneTexture,
            let glowTexture = self.glowTexture
        else { return }

        // NOTE(review): these buffers are allocated once and rewritten every
        // frame without waiting for the GPU to finish reading the previous
        // frame. Consider a ring of buffers guarded by a semaphore.
        var vBufferWrapper = ConstantBuffer<Float>(vertexBuffer)
        var pBufferWrapper = ConstantBuffer<Float>(progressBuffer)
        var uBufferWrapper = ConstantBuffer<FireworkScene.Uniform>(uniformBuffer)

        let firework = FireworkScene(
            trailStartPoint: SIMD2<Float>(-0.4, -0.8),
            trailEndPoint: SIMD2<Float>(0.1, 0.5),
            trailControlPoint: SIMD2<Float>(-0.3, -0.1),
            ratio: Float(view.bounds.height / view.bounds.width),
            trailWidth: 0.03
        )
        let elapsedTime = Float(CACurrentMediaTime() - initialTime).truncatingRemainder(dividingBy: duration)

        // One instance of the firework per uniform. Later instances start
        // 0.2 and 0.4 time units after the first.
        let uniforms: [FireworkScene.Uniform] = [
            FireworkScene.Uniform(
                color: SIMD4<Float>(208.0 / 255.0, 80.0 / 255.0, 111.0 / 255.0, 1.0),
                transform: .init(translationX: 0.1, y: 0.0),
                time: easeOutCubic(elapsedTime / duration)
            ),
            FireworkScene.Uniform(
                color: SIMD4<Float>(100.0 / 255.0, 167.0 / 255.0, 230.0 / 255.0, 1.0),
                transform: .init(translationX: -0.1, y: 0.0),
                time: easeOutCubic((elapsedTime - 0.2) / duration)
            ),
            FireworkScene.Uniform(
                color: SIMD4<Float>(90.0 / 255.0, 98.0 / 255.0, 198.0 / 255.0, 1.0),
                transform: .init(translationX: 0.0, y: 0.2),
                time: easeOutCubic((elapsedTime - 0.4) / duration)
            ),
        ]
        for var uniform in uniforms {
            uBufferWrapper.write(value: &uniform)
        }

        firework.draw(
            vertexBuffer: &vBufferWrapper,
            progressBuffer: &pBufferWrapper
        )

        guard
            let commandBuffer = commandQueue.makeCommandBuffer(),
            let drawable = view.currentDrawable
        else { return }

        // Each vertex occupies four floats in the vertex buffer.
        let vertexCount = vBufferWrapper.position / 4

        if
            let renderPassDescriptor = sceneRenderPassDescriptor(sceneTexture),
            let renderEncoder = commandBuffer.makeRenderCommandEncoder(
                descriptor: renderPassDescriptor
            )
        {
            renderEncoder.setRenderPipelineState(scenePipelineState)
            renderEncoder.setVertexBuffer(
                vertexBuffer,
                offset: 0,
                index: 0
            )
            renderEncoder.setVertexBuffer(
                progressBuffer,
                offset: 0,
                index: 1
            )
            renderEncoder.setVertexBuffer(
                uniformBuffer,
                offset: 0,
                index: 2
            )
            renderEncoder.drawPrimitives(
                type: .triangle,
                vertexStart: 0,
                vertexCount: vertexCount,
                instanceCount: uniforms.count
            )
            renderEncoder.endEncoding()
        }

        // NOTE(review): the kernel is recreated every frame; it could be
        // created once and cached in a stored property.
        let kernel = MPSImageGaussianBlur(
            device: device,
            sigma: 40.0
        )
        kernel.encode(
            commandBuffer: commandBuffer,
            sourceTexture: sceneTexture,
            destinationTexture: glowTexture
        )

        if
            let renderPassDescriptor = sceneRenderPassDescriptor(drawable.texture),
            let renderEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor)
        {
            renderEncoder.setRenderPipelineState(compositePipelineState)
            renderEncoder.setVertexBuffer(sceneBuffer, offset: 0, index: 0)
            renderEncoder.setFragmentTexture(sceneTexture, index: 0)
            renderEncoder.setFragmentTexture(glowTexture, index: 1)
            renderEncoder.drawPrimitives(
                type: .triangleStrip,
                vertexStart: 0,
                vertexCount: 4
            )
            renderEncoder.endEncoding()
        }

        commandBuffer.present(drawable)
        commandBuffer.commit()
    }

    /// Rebuilds the offscreen render targets at the new drawable size.
    func mtkView(
        _ view: MTKView,
        drawableSizeWillChange size: CGSize
    ) {
        buildResources(size: size)
    }
}
private extension FireworkViewController {
    /// Builds the scene and composition render pipelines from the app's default Metal library.
    ///
    /// Terminates with a descriptive message if the library or either pipeline
    /// cannot be created.
    func buildPipelineStates() {
        guard
            let library = device.makeDefaultLibrary()
        else {
            // Returning silently would leave both pipeline states nil, and
            // `draw(in:)` would then crash. Fail here with a clear message instead.
            fatalError("Failed to load the default Metal library")
        }

        // Scene pass: renders into the offscreen bgra8Unorm `sceneTexture`
        // with standard source-over alpha blending.
        let scenePipelineDescriptor = MTLRenderPipelineDescriptor()
        scenePipelineDescriptor.vertexFunction = library.makeFunction(name: "Firework::vertexScene")
        scenePipelineDescriptor.fragmentFunction = library.makeFunction(name: "Firework::fragmentScene")
        scenePipelineDescriptor.colorAttachments[0].pixelFormat = .bgra8Unorm
        scenePipelineDescriptor.colorAttachments[0].isBlendingEnabled = true
        scenePipelineDescriptor.colorAttachments[0].rgbBlendOperation = .add
        scenePipelineDescriptor.colorAttachments[0].alphaBlendOperation = .add
        scenePipelineDescriptor.colorAttachments[0].sourceRGBBlendFactor = .sourceAlpha
        scenePipelineDescriptor.colorAttachments[0].sourceAlphaBlendFactor = .sourceAlpha
        scenePipelineDescriptor.colorAttachments[0].destinationRGBBlendFactor = .oneMinusSourceAlpha
        scenePipelineDescriptor.colorAttachments[0].destinationAlphaBlendFactor = .oneMinusSourceAlpha
        do {
            scenePipelineState = try device.makeRenderPipelineState(descriptor: scenePipelineDescriptor)
        } catch {
            fatalError("Failed to create scene pipeline state: \(error)")
        }

        // Composition pass: renders into the view's drawable.
        let compositePipelineDescriptor = MTLRenderPipelineDescriptor()
        compositePipelineDescriptor.vertexFunction = library.makeFunction(
            name: "Firework::vertexComposition"
        )
        compositePipelineDescriptor.fragmentFunction = library.makeFunction(
            name: "Firework::fragmentComposition"
        )
        // Must match the drawable's format, so take it from the view
        // (.bgra8Unorm by default) rather than hard-coding it.
        compositePipelineDescriptor.colorAttachments[0].pixelFormat = canvasView.colorPixelFormat
        do {
            compositePipelineState = try device.makeRenderPipelineState(
                descriptor: compositePipelineDescriptor
            )
        } catch {
            fatalError("Failed to create composite pipeline state: \(error)")
        }
    }

    /// Allocates the per-frame geometry/uniform buffers and the static full-screen quad.
    func buildBuffers() {
        // Fixed capacity in bytes for each per-frame buffer.
        let dynamicBufferLength = 10 * 1000 * 1000
        vertexBuffer = device.makeBuffer(
            length: dynamicBufferLength
        )
        progressBuffer = device.makeBuffer(
            length: dynamicBufferLength
        )
        uniformBuffer = device.makeBuffer(
            length: dynamicBufferLength
        )
        // Triangle-strip quad. Each vertex is (x, y) position followed by (u, v)
        // texture coordinate, read as `.xy` / `.zw` by `vertexComposition`.
        let quadVertices: [Float] = [
            -1.0, 1.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 0.0,
            -1.0, -1.0, 0.0, 1.0,
            1.0, -1.0, 1.0, 1.0,
        ]
        sceneBuffer = device.makeBuffer(
            bytes: quadVertices,
            length: MemoryLayout<Float>.stride * quadVertices.count,
            options: []
        )
    }

    /// (Re)creates the offscreen scene and glow render targets at `size` pixels.
    ///
    /// Zero-sized requests are ignored. Existing targets are kept when they
    /// already have the requested size.
    func buildResources(size: CGSize) {
        let width = Int(size.width)
        let height = Int(size.height)
        // Metal cannot create zero-sized textures; wait for a real layout.
        guard width > 0, height > 0 else { return }
        // Layout callbacks fire repeatedly; do not reallocate identical targets.
        if
            let scene = sceneTexture,
            let glow = glowTexture,
            scene.width == width, scene.height == height,
            glow.width == width, glow.height == height
        {
            return
        }
        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .bgra8Unorm,
            width: width,
            height: height,
            mipmapped: false
        )
        textureDescriptor.usage = [.renderTarget, .shaderRead, .shaderWrite]
        sceneTexture = device.makeTexture(descriptor: textureDescriptor)
        glowTexture = device.makeTexture(descriptor: textureDescriptor)
    }

    /// Render pass targeting `texture` that clears it to opaque black and stores the result.
    func sceneRenderPassDescriptor(_ texture: any MTLTexture) -> MTLRenderPassDescriptor? {
        let descriptor = MTLRenderPassDescriptor()
        descriptor.colorAttachments[0].texture = texture
        descriptor.colorAttachments[0].loadAction = .clear
        descriptor.colorAttachments[0].storeAction = .store
        descriptor.colorAttachments[0].clearColor = MTLClearColorMake(0, 0, 0, 1)
        return descriptor
    }
}
#if os(macOS)
/// SwiftUI wrapper that embeds `FireworkViewController` on macOS.
struct FireworkView: NSViewControllerRepresentable {
    func makeNSViewController(context: Context) -> some NSViewController {
        let controller = FireworkViewController()
        return controller
    }

    /// The animation has no SwiftUI-driven state, so there is nothing to update.
    func updateNSViewController(
        _ controller: NSViewControllerType,
        context: Context
    ) {}
}
#else
/// SwiftUI wrapper that embeds `FireworkViewController` on UIKit platforms.
struct FireworkView: UIViewControllerRepresentable {
    func makeUIViewController(context: Context) -> some UIViewController {
        let controller = FireworkViewController()
        return controller
    }

    /// The animation has no SwiftUI-driven state, so there is nothing to update.
    func updateUIViewController(
        _ controller: UIViewControllerType,
        context: Context
    ) {}
}
#endif
@uvolchyk
Copy link
Author

drawn-multiple.mp4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment