Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save acsr/8f23377f1319533b123d45b73c8b4c2e to your computer and use it in GitHub Desktop.
Save acsr/8f23377f1319533b123d45b73c8b4c2e to your computer and use it in GitHub Desktop.
RealityKit Morph To Sphere with LowLevelMesh and LowLevelTexture
#include <metal_stdlib>
using namespace metal;
#include "MorphToSphereParams.h"
/// Fades each texel of `inTexture` toward white and writes the result to `outTexture`.
///
/// - inTexture:  source texels (texture index 0), read only.
/// - outTexture: destination texels (texture index 1). NOTE(review): assumed to match
///               inTexture's size; the host creates both from the same resource.
/// - params:     `progress` is the blend amount, clamped to [0, 1]
///               (0 = original color, 1 = white). Alpha is left unchanged.
/// - gid:        the texel coordinate handled by this thread.
kernel void fadeTextureToWhite(texture2d<float, access::read> inTexture [[texture(0)]],
                               texture2d<float, access::write> outTexture [[texture(1)]],
                               constant ProcessTextureParams& params [[buffer(0)]],
                               uint2 gid [[thread_position_in_grid]])
{
    // A grid rounded up to whole threadgroups has threads past the texture edge;
    // reading or writing there is out of bounds, so those threads do nothing.
    if (gid.x >= inTexture.get_width() || gid.y >= inTexture.get_height() ||
        gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) {
        return;
    }

    float4 color = inTexture.read(gid);

    // Every texel approaches the same target color.
    const float3 targetColor = float3(1.0, 1.0, 1.0); // white

    // Blend only the RGB channels; saturate guards against out-of-range progress.
    color.rgb = mix(color.rgb, targetColor, saturate(params.progress));
    outTexture.write(color, gid);
}
import SwiftUI
import RealityKit
import Metal
/// Demo view that morphs a downloaded model into a sphere and back while fading its
/// base-color texture toward white. On each timer tick it drives a `LowLevelMesh` and a
/// `LowLevelTexture` with Metal compute kernels.
struct MorphModelToSphereView: View {
    /// The loaded model; its mesh and material are replaced with low-level resources.
    @State var entity: ModelEntity?
    /// GPU-updatable mesh that replaces the model's original mesh.
    @State var lowLevelMesh: LowLevelMesh?
    /// Unmorphed vertices, copied back into the mesh before every morph pass.
    @State var originalVertices: [VertexData] = []
    /// Unmodified copy of the base-color texture, read by the fade kernel.
    @State var originalTexture: LowLevelTexture?
    /// Texture bound to the material and rewritten by the fade kernel.
    @State var processedTexture: LowLevelTexture?
    /// Repeating timer that advances the animation; `nil` while stopped.
    @State var timer: Timer?
    /// `true` while morphing toward the sphere, `false` while morphing back.
    @State var isForward: Bool = true
    /// Morph amount in [0, 1]: 0 is the original model, 1 is the full sphere.
    @State var morphProgress: Float = 0.0
    /// Timer ticks spent pausing at the current end point.
    @State var dwellCounter: Int = 0
    /// `true` while pausing at 0 or 1 before reversing direction.
    @State var isDwelling: Bool = false
    let timerUpdateDuration: TimeInterval = 1/120.0
    let dwellDuration: Int = 60 // ticks to dwell (0.5 seconds at 120 ticks per second)
    /// Progress change applied per timer tick.
    var morphProgressRate: Float = 0.01
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    /// Pipeline for the `morphVerticesToSphere` kernel.
    let computePipelineState: MTLComputePipelineState
    /// Pipeline for the `fadeTextureToWhite` kernel.
    let textureComputePipeline: MTLComputePipelineState

    /// Radius of the target sphere.
    ///
    /// This is 37.5% of the mesh's bounding radius, or 0.5 before a model is loaded.
    var morphRadius: Float {
        guard let bounds = entity?.model?.mesh.bounds else {
            return 0.5 // just some default
        }
        return bounds.boundingRadius * 0.375
    }

    /// Creates the Metal device, command queue, and both compute pipelines.
    ///
    /// Traps with a descriptive message if Metal is unavailable or the kernels are missing
    /// from the default library. Both cases are setup errors, not recoverable states.
    init() {
        guard let device = MTLCreateSystemDefaultDevice(),
              let commandQueue = device.makeCommandQueue() else {
            fatalError("Metal is not available on this device")
        }
        guard let library = device.makeDefaultLibrary(),
              let morphFunction = library.makeFunction(name: "morphVerticesToSphere"),
              let fadeFunction = library.makeFunction(name: "fadeTextureToWhite") else {
            fatalError("Required Metal kernels are missing from the default library")
        }
        self.device = device
        self.commandQueue = commandQueue
        do {
            self.computePipelineState = try device.makeComputePipelineState(function: morphFunction)
            self.textureComputePipeline = try device.makeComputePipelineState(function: fadeFunction)
        } catch {
            fatalError("Failed to create compute pipelines: \(error)")
        }
    }

    var body: some View {
        RealityView { content in
            let model = try! await loadModelEntity()
            content.add(model)

            let lowLevelMesh = try! createMesh(from: model)
            // Snapshot the unmorphed vertices so every morph pass can start from them.
            lowLevelMesh.withUnsafeBytes(bufferIndex: 0) { buffer in
                let vertices = buffer.bindMemory(to: VertexData.self)
                self.originalVertices = Array(vertices)
            }
            // Swap out the model's mesh for our GPU-updatable LowLevelMesh.
            model.model?.mesh = try! await MeshResource(from: lowLevelMesh)
            self.entity = model
            // Just keeping the entire model visible in preview.
            model.scale *= 0.9
            self.lowLevelMesh = lowLevelMesh
            // Bring the mesh in sync with the current progress; the timer may already be running.
            updateMesh()

            // Texture fading needs a PBR material with a base-color texture. Without one,
            // keep the original material and only morph the mesh instead of crashing.
            guard var material = model.model?.materials.first as? PhysicallyBasedMaterial,
                  let baseColorTexture = material.baseColor.texture?.resource else { return }

            // Keep an untouched copy of the base color as a LowLevelTexture (maybe there's a better type here?).
            let originalTexture = try! copyTextureResourceToLowLevelTexture(from: baseColorTexture)
            // The material gets a second LowLevelTexture that the fade kernel rewrites.
            let processedTexture = try! copyTextureResourceToLowLevelTexture(from: baseColorTexture)
            let newTextureResource = try! await TextureResource(from: processedTexture)
            material.baseColor.texture = .init(newTextureResource)
            material.metallic = 1.0
            material.roughness = 0.125
            model.model?.materials = [material]

            self.originalTexture = originalTexture
            self.processedTexture = processedTexture
            // Both textures are stored now, so this first fade pass actually runs.
            updateTexture()
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts the repeating animation timer.
    ///
    /// Each tick either advances the dwell counter at an end point or steps `morphProgress`
    /// by `morphProgressRate`, then refreshes the mesh and texture.
    func startTimer() {
        // onAppear can fire more than once; never leave an earlier timer running.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: timerUpdateDuration, repeats: true) { _ in
            if isDwelling {
                // Pause at the end point; once the dwell elapses, reverse direction.
                dwellCounter += 1
                if dwellCounter >= dwellDuration {
                    isDwelling = false
                    dwellCounter = 0
                    isForward.toggle()
                }
                // Progress did not change, so re-running the GPU passes would be redundant.
                return
            }

            morphProgress += isForward ? morphProgressRate : -morphProgressRate

            // Clamp at either end and begin dwelling there.
            if morphProgress >= 1.0 {
                morphProgress = 1.0
                isDwelling = true
                dwellCounter = 0
            } else if morphProgress <= 0.0 {
                morphProgress = 0.0
                isDwelling = true
                dwellCounter = 0
            }
            updateMesh()
            updateTexture()
        }
    }

    /// Invalidates and releases the animation timer, if any.
    func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Errors thrown by `createMesh(from:)`.
    enum MeshCreationError: Error {
        case modelNotFound, meshPartNotFound
    }
}
// MARK: Download model
extension MorphModelToSphereView {
    /// Downloads a USDZ model and loads it as a `ModelEntity`.
    ///
    /// The download is moved to a uniquely named `.usdz` file in the temporary directory
    /// and loaded from there. The file is removed afterwards, even if loading fails.
    ///
    /// - Parameter url: Remote USDZ location; defaults to `ExampleModels.crystal`.
    /// - Returns: The loaded entity.
    /// - Throws: `URLError(.badServerResponse)` for a non-2xx HTTP status, or any
    ///   networking, file-system, or model-loading error.
    func loadModelEntity(url: URL = ExampleModels.crystal.url) async throws -> ModelEntity {
        let fileManager = FileManager.default
        let (downloadedURL, response) = try await URLSession.shared.download(from: url)

        // Don't hand an HTTP error body (e.g. a 404 page) to the model loader.
        if let httpResponse = response as? HTTPURLResponse,
           !(200..<300).contains(httpResponse.statusCode) {
            try? fileManager.removeItem(at: downloadedURL)
            throw URLError(.badServerResponse)
        }

        // A unique name avoids collisions between concurrent loads; keep the .usdz extension.
        let destinationURL = fileManager.temporaryDirectory
            .appendingPathComponent(UUID().uuidString)
            .appendingPathExtension("usdz")
        try fileManager.moveItem(at: downloadedURL, to: destinationURL)
        // Clean up on every exit path, including a failed load.
        defer { try? fileManager.removeItem(at: destinationURL) }

        return try await ModelEntity(contentsOf: destinationURL)
    }
}
// MARK: Mesh functions
extension MorphModelToSphereView {
    /// Builds a `LowLevelMesh` from the first part of the first model in `modelEntity`'s mesh.
    ///
    /// Positions, normals, UVs, and triangle indices are copied into a vertex buffer laid out
    /// as `VertexData`. If the source has no normals (or too few), each missing normal falls
    /// back to the direction from the origin. Missing UVs fall back to zero. Either way the
    /// copy no longer traps on an out-of-range index.
    ///
    /// - Parameter modelEntity: The entity whose mesh is copied.
    /// - Returns: A single-part, triangle-topology mesh with `UInt32` indices.
    /// - Throws: `MeshCreationError.modelNotFound` when the entity has no model component,
    ///   `MeshCreationError.meshPartNotFound` when the mesh has no parts, or any error from
    ///   creating the `LowLevelMesh`.
    /// - Note: Only the first part is copied; additional models or parts are ignored.
    func createMesh(from modelEntity: ModelEntity) throws -> LowLevelMesh {
        guard let model = modelEntity.model
        else { throw MeshCreationError.modelNotFound }
        guard let meshPart = model.mesh.contents.models.first?.parts.first
        else { throw MeshCreationError.meshPartNotFound }

        let positions = meshPart[MeshBuffers.positions]?.elements ?? []
        let normals = meshPart[MeshBuffers.normals]?.elements ?? []
        let textureCoordinates = meshPart[MeshBuffers.textureCoordinates]?.elements ?? []
        let triangleIndices = meshPart.triangleIndices?.elements ?? []

        let lowLevelMesh = try VertexData.initializeMesh(vertexCapacity: positions.count,
                                                         indexCapacity: triangleIndices.count)

        // Copy vertex data, binding to VertexData itself so these writes use the same
        // layout that updateMesh and the Metal kernel read.
        lowLevelMesh.withUnsafeMutableBytes(bufferIndex: 0) { buffer in
            let vertices = buffer.bindMemory(to: VertexData.self)
            for (i, position) in positions.enumerated() {
                let normal = i < normals.count ? normals[i] : Self.fallbackNormal(for: position)
                let uv = i < textureCoordinates.count ? textureCoordinates[i] : .zero
                vertices[i] = VertexData(position: position, normal: normal, uv: uv)
            }
        }

        // Copy index data.
        lowLevelMesh.withUnsafeMutableIndices { buffer in
            let indices = buffer.bindMemory(to: UInt32.self)
            for (i, triangleIndex) in triangleIndices.enumerated() {
                indices[i] = UInt32(triangleIndex)
            }
        }

        // A single part covering every index, reusing the source mesh's bounds.
        lowLevelMesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: triangleIndices.count,
                topology: .triangle,
                bounds: model.mesh.bounds
            )
        ])
        return lowLevelMesh
    }

    /// Normal used when the source mesh lacks one.
    ///
    /// Returns the unit direction from the origin to `position`, or +Y for a vertex at the origin.
    private static func fallbackNormal(for position: SIMD3<Float>) -> SIMD3<Float> {
        let length = (position * position).sum().squareRoot()
        return length > 0.0001 ? position / length : SIMD3<Float>(0, 1, 0)
    }

    /// Morphs the mesh toward a sphere of `morphRadius` by `morphProgress`.
    ///
    /// The vertex buffer is first reset from `originalVertices` on the CPU. The
    /// `morphVerticesToSphere` kernel then rewrites every vertex on the GPU.
    /// Does nothing until a non-empty mesh exists.
    func updateMesh() {
        guard let mesh = lowLevelMesh, mesh.vertexCapacity > 0 else { return }

        // Reset mesh to its original state. The prefix keeps a size mismatch from
        // overrunning the buffer.
        mesh.withUnsafeMutableBytes(bufferIndex: 0) { buffer in
            let vertices = buffer.bindMemory(to: VertexData.self)
            for (i, vertex) in originalVertices.prefix(vertices.count).enumerated() {
                vertices[i] = vertex
            }
        }

        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        // NOTE(review): this relies on `replace(bufferIndex:using:)` exposing the data written
        // just above. Confirm against the LowLevelMesh docs. If the returned buffer's contents
        // are undefined, the kernel should read originals from a separate input buffer.
        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        var params = MorphToSphereParams(radius: morphRadius, progress: morphProgress)
        computeEncoder.setBytes(&params, length: MemoryLayout<MorphToSphereParams>.stride, index: 1)

        // One thread per vertex; dispatchThreads launches exactly this many.
        let threadsPerGrid = MTLSize(width: mesh.vertexCapacity, height: 1, depth: 1)
        let threadsPerThreadgroup = MTLSize(width: 64, height: 1, depth: 1)
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()
    }
}
// MARK: Texture functions
extension MorphModelToSphereView {
    /// Copies `textureResource` into a new single-mip 2D `LowLevelTexture`.
    ///
    /// The new texture uses `.rgba16Float`, matches the resource's width and height, and is
    /// created with `.shaderRead` and `.shaderWrite` usage.
    ///
    /// - Parameter textureResource: The texture to copy.
    /// - Returns: A texture holding a copy of the resource's contents.
    /// - Throws: Any error from creating the `LowLevelTexture` or from the copy.
    func copyTextureResourceToLowLevelTexture(from textureResource: TextureResource) throws -> LowLevelTexture {
        var descriptor = LowLevelTexture.Descriptor()
        descriptor.textureType = .type2D
        descriptor.pixelFormat = .rgba16Float
        descriptor.width = textureResource.width
        descriptor.height = textureResource.height
        descriptor.mipmapLevelCount = 1
        descriptor.textureUsage = [.shaderRead, .shaderWrite]
        let texture = try LowLevelTexture(descriptor: descriptor)
        try textureResource.copy(to: texture.read())
        return texture
    }

    /// Rewrites `processedTexture` as `originalTexture` faded toward white by `morphProgress`.
    ///
    /// Runs the `fadeTextureToWhite` kernel. Does nothing until both textures exist or if a
    /// command buffer/encoder cannot be created.
    func updateTexture() {
        guard let original = originalTexture,
              let faded = processedTexture,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        computeEncoder.setComputePipelineState(textureComputePipeline)
        computeEncoder.setTexture(original.read(), index: 0)
        computeEncoder.setTexture(faded.replace(using: commandBuffer), index: 1)
        var params = ProcessTextureParams(progress: morphProgress)
        computeEncoder.setBytes(&params, length: MemoryLayout<ProcessTextureParams>.stride, index: 0)

        // Exactly one thread per texel. Rounding up to whole 8x8 threadgroups would launch
        // threads past the texture edge whenever a dimension isn't a multiple of 8.
        let threadsPerGrid = MTLSize(width: original.descriptor.width,
                                     height: original.descriptor.height,
                                     depth: 1)
        let threadsPerThreadgroup = MTLSize(width: 8, height: 8, depth: 1)
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()
    }
}
// Xcode canvas preview. Note: the previewed view downloads its model over the network.
#Preview {
MorphModelToSphereView()
}
extension VertexData {
    /// Vertex attributes matching the `VertexData` struct: position, normal, and first UV set.
    ///
    /// Offsets come from `MemoryLayout`, so they track the struct's actual layout. This is
    /// `let` because the table never changes, and a mutable static would be shared global
    /// state under Swift concurrency checking.
    static let vertexAttributes: [LowLevelMesh.Attribute] = [
        .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
        .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!),
        .init(semantic: .uv0, format: .float2, offset: MemoryLayout<Self>.offset(of: \.uv)!)
    ]

    /// A single interleaved vertex buffer at index 0 with one `VertexData` per vertex.
    static let vertexLayouts: [LowLevelMesh.Layout] = [
        .init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)
    ]

    /// Mesh descriptor using the attributes and layouts above with 32-bit indices.
    ///
    /// Capacities are left at their defaults.
    static var descriptor: LowLevelMesh.Descriptor {
        var desc = LowLevelMesh.Descriptor()
        desc.vertexAttributes = VertexData.vertexAttributes
        desc.vertexLayouts = VertexData.vertexLayouts
        desc.indexType = .uint32
        return desc
    }

    /// Creates a `LowLevelMesh` sized for the given vertex and index counts.
    ///
    /// - Parameters:
    ///   - vertexCapacity: Number of vertices the mesh can hold.
    ///   - indexCapacity: Number of indices the mesh can hold.
    /// - Throws: Any error from `LowLevelMesh(descriptor:)`.
    @MainActor static func initializeMesh(vertexCapacity: Int,
                                          indexCapacity: Int) throws -> LowLevelMesh {
        var desc = VertexData.descriptor
        desc.vertexCapacity = vertexCapacity
        desc.indexCapacity = indexCapacity
        return try LowLevelMesh(descriptor: desc)
    }
}
/// Sample USDZ models hosted for this demo.
enum ExampleModels {
    case crystal
    case buddha
    case rock
    case laserGun

    /// Directory that hosts every example model.
    static let baseURL = URL(string: "https://matt54.github.io/Resources/")!

    /// Remote location of this model's `.usdz` file, under `baseURL`.
    var url: URL {
        let fileNameWithExtension = filename + ".usdz"
        return Self.baseURL.appendingPathComponent(fileNameWithExtension)
    }

    /// Base file name of this model, without the `.usdz` extension.
    var filename: String {
        switch self {
        case .crystal: "Crystal_1"
        case .buddha: "StatueOfBuddha"
        case .rock: "Rock_1"
        case .laserGun: "laser_gun"
        }
    }
}
#ifndef MorphToSphereParams_h
#define MorphToSphereParams_h
// Uniform structs shared by the Swift host code and the Metal kernels. Swift passes them
// as raw bytes via setBytes, so the field order here must match what both sides expect.
// Uniforms for morphVerticesToSphere, bound at buffer index 1.
struct MorphToSphereParams {
float radius; // target sphere radius, in the same units as the vertex positions
float progress; // morph amount; the host clamps it to [0, 1] (0 = original, 1 = sphere)
};
// Uniforms for fadeTextureToWhite, bound at buffer index 0.
struct ProcessTextureParams {
float progress; // blend toward white (0 = original color, 1 = white)
};
#endif /* MorphToSphereParams_h */
#include <metal_stdlib>
using namespace metal;
#include "VertexData.h"
#include "MorphToSphereParams.h"
/// Moves one vertex toward a sphere of radius `params.radius` centered at the origin.
///
/// The position is blended from its original value toward the point at the same direction
/// on the sphere. The normal is blended toward that direction and renormalized.
/// - vertices: interleaved VertexData, updated in place (buffer index 0).
/// - params:   sphere radius and morph progress (buffer index 1); progress is clamped to [0, 1].
/// - vid:      vertex index. NOTE(review): no bounds check; the host dispatches exactly one
///             thread per vertex, so keep it that way.
kernel void morphVerticesToSphere(device VertexData* vertices [[buffer(0)]],
                                  constant MorphToSphereParams& params [[buffer(1)]],
                                  uint vid [[thread_position_in_grid]])
{
    float3 originalPosition = vertices[vid].position;
    float3 originalNormal = vertices[vid].normal;
    float progress = saturate(params.progress);

    // Distance from the center; a vertex at the center keeps its position and normal.
    float distance = length(originalPosition);
    float3 spherePosition = originalPosition;
    float3 sphereNormal = originalNormal;

    if (distance > 0.0001) { // prevent divide-by-zero
        // Project onto the sphere surface; on a sphere the normal is the radial direction.
        float3 normalizedDirection = originalPosition / distance;
        spherePosition = normalizedDirection * params.radius;
        sphereNormal = normalizedDirection;
    }

    vertices[vid].position = mix(originalPosition, spherePosition, progress);

    // Opposing original and sphere normals cancel out near progress 0.5, and normalizing a
    // zero vector yields NaN. Fall back to the sphere normal in that degenerate case.
    float3 interpolatedNormal = mix(originalNormal, sphereNormal, progress);
    vertices[vid].normal = length_squared(interpolatedNormal) > 1e-12f
        ? normalize(interpolatedNormal)
        : sphereNormal;
}
#ifndef VertexData_h
#define VertexData_h

#include <simd/simd.h>

// Per-vertex layout shared by the Swift LowLevelMesh setup and the Metal morph kernel.
// Swift derives the attribute offsets and buffer stride from this struct via MemoryLayout,
// so reordering fields changes the vertex layout on both sides.
struct VertexData {
    simd_float3 position;
    simd_float3 normal;
    simd_float2 uv;
};

#endif /* VertexData_h */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment