Skip to content

Instantly share code, notes, and snippets.

@jacobsapps
Created June 4, 2025 06:23
Show Gist options
  • Save jacobsapps/4400f88163e94ca7b3508774f98e7581 to your computer and use it in GitHub Desktop.
Save jacobsapps/4400f88163e94ca7b3508774f98e7581 to your computer and use it in GitHub Desktop.
CoreMLProcessing.swift
import Vision
import UIKit
import ImageIO
/// One candidate card title together with its precomputed embedding vector.
///
/// Decoded from the bundled `titles_embeddings.json` file.
struct TitleEmbedding: Codable {
    /// Title template; `ImagePredictor` replaces every "X" in it with the supplied name.
    let fullNameCopy: String

    /// Keyword stored alongside the title (not read by the code in this file).
    let keyword: String

    /// Vector compared against the normalized image embedding via a dot product.
    let embedding: [Float]
}
/// One stat label together with its precomputed embedding vector.
///
/// Decoded from the bundled `stats_embeddings.json` file.
struct StatEmbedding: Codable {
    /// Stat name; `ImagePredictor` looks for "Attack", "Defence" and "Magic" by this value.
    let label: String

    /// Vector compared against the normalized image embedding via a dot product.
    let embedding: [Float]
}
/// One elemental type together with its display attributes and precomputed embedding vector.
///
/// Decoded from the bundled `type_embeddings.json` file.
struct TypeEmbedding: Codable, Identifiable {
    /// Type field from the JSON record (not read by the code in this file).
    let type: String

    /// Name of the type; also serves as the `Identifiable` identity.
    let name: String

    /// Emoji passed on to the resulting `ElementalType`.
    let emoji: String

    /// Color field from the JSON record (not read by the code in this file).
    let color: String

    /// Hex color string passed on to the resulting `ElementalType`.
    let hex: String

    /// Vector compared against the normalized image embedding via a dot product.
    let embedding: [Float]

    /// Identity for `Identifiable`; simply the type's `name`.
    var id: String {
        return name
    }
}
/// Runs a bundled Core ML image model through Vision and converts the resulting
/// image embedding into card attributes (title, type, stats, rarity) by comparing
/// it with precomputed embeddings loaded from JSON files in the main bundle.
final class ImagePredictor {
    private var model: VNCoreMLModel? = nil
    private var statsEmbeddings: [StatEmbedding] = []
    private var titleEmbeddings: [TitleEmbedding] = []
    private var typeEmbeddings: [TypeEmbedding] = []

    /// Loads the compiled model and the three embedding tables from the main bundle.
    ///
    /// Any previously loaded state is discarded first. A missing bundle resource or a
    /// model that cannot be loaded is treated as a programmer error and traps. A JSON
    /// file that exists but cannot be read or decoded is logged; that table and the
    /// ones loaded after it stay empty, and the prediction helpers tolerate empty tables.
    func loadCLIPModelAndEmbeddings() {
        model = nil
        statsEmbeddings = []
        titleEmbeddings = []
        typeEmbeddings = []

        let defaultConfig = MLModelConfiguration()
        guard let modelURL = Bundle.main.url(forResource: "mobileclip_blt_image", withExtension: "mlmodelc") else {
            fatalError("Model file not found in the bundle.")
        }

        let imageClassifierModel: MLModel
        do {
            imageClassifierModel = try MLModel(contentsOf: modelURL, configuration: defaultConfig)
        } catch {
            fatalError("Failed to load the model from file: \(error.localizedDescription)")
        }

        guard let imageClassifierVisionModel = try? VNCoreMLModel(for: imageClassifierModel) else {
            fatalError("App failed to create a `VNCoreMLModel` instance.")
        }
        model = imageClassifierVisionModel

        do {
            // Loaded in sequence: the first failure skips the remaining tables.
            titleEmbeddings = try loadEmbeddings(named: "titles_embeddings")
            statsEmbeddings = try loadEmbeddings(named: "stats_embeddings")
            typeEmbeddings = try loadEmbeddings(named: "type_embeddings")
        } catch {
            print("ImagePredictor: failed to load embeddings: \(error)")
        }
    }

    /// Decodes a JSON array of embeddings from `<resource>.json` in the main bundle.
    ///
    /// - Parameter resource: Resource name without the `.json` extension.
    /// - Returns: The decoded records.
    /// - Throws: Any error raised while reading or decoding the file.
    private func loadEmbeddings<T: Decodable>(named resource: String) throws -> [T] {
        guard let url = Bundle.main.url(forResource: resource, withExtension: "json") else {
            fatalError("Embedding file \(resource).json not found.")
        }
        let data = try Data(contentsOf: url)
        return try JSONDecoder().decode([T].self, from: data)
    }

    /// Runs the model on `photo` and derives card stats from the resulting embedding.
    ///
    /// - Parameters:
    ///   - photo: Image to analyse; must be backed by a `CGImage`.
    ///   - name: Name substituted into the predicted title.
    /// - Returns: The predicted stats, or `nil` when no model has been loaded.
    /// - Throws: An `NSError` (domain "ImagePredictor", code 1) when the photo has no
    ///   `CGImage`, any error thrown by Vision while performing the request, or an
    ///   `NSError` (code 2) when the request did not produce a usable embedding.
    func makePredictions(for photo: UIImage, name: String) async throws -> CardStats? {
        guard let model else { return nil }

        let orientation = CGImagePropertyOrientation(photo.imageOrientation)
        guard let photoImage = photo.cgImage else {
            throw NSError(domain: "ImagePredictor", code: 1, userInfo: [NSLocalizedDescriptionKey: "Photo doesn't have underlying CGImage."])
        }

        // `perform` is synchronous, so the results are read from the request afterwards
        // instead of bridging the completion handler through a continuation. A failing
        // request can be reported both to the completion handler and as a thrown error,
        // which would resume a checked continuation twice and trap.
        let request = VNCoreMLRequest(model: model)
        request.imageCropAndScaleOption = .centerCrop

        let handler = VNImageRequestHandler(cgImage: photoImage, orientation: orientation)
        try handler.perform([request])

        guard let cardStats = handleCLIPPrediction(from: request, name: name) else {
            let description = "Unexpected result type: \(String(describing: request.results))"
            throw NSError(domain: "ImagePredictor", code: 2, userInfo: [NSLocalizedDescriptionKey: description])
        }
        return cardStats
    }

    /// Extracts the embedding from a finished request, L2-normalizes it and derives
    /// every card attribute from it.
    ///
    /// - Returns: `nil` when the first result is not a feature-value observation
    ///   carrying a multi-array.
    private func handleCLIPPrediction(from request: VNRequest, name: String) -> CardStats? {
        guard let result = request.results?.first as? VNCoreMLFeatureValueObservation,
              let imageEmbedding = result.featureValue.multiArrayValue else {
            return nil
        }

        let floatArray: [Float] = (0..<imageEmbedding.count).map {
            Float(truncating: imageEmbedding[$0])
        }

        // Guard the division: a zero (or non-finite) norm would otherwise turn every
        // component into NaN and poison all downstream scores.
        let norm = floatArray.reduce(0) { $0 + $1 * $1 }.squareRoot()
        let normalized = (norm > 0 && norm.isFinite) ? floatArray.map { $0 / norm } : floatArray

        let stats = predictStats(from: normalized)
        let title = predictTitle(from: normalized, name)
        let type = predictType(from: normalized)
        let rarity = checkRarity(from: normalized)
        return CardStats(title: title, type: type, stats: stats, rarity: rarity)
    }

    /// Builds the stat list: "Attack", "Defence" and "Magic" first, followed by the
    /// three best-scoring remaining stats (scaled by 0.66).
    ///
    /// Each similarity score is converted to a value by taking its magnitude,
    /// multiplying by 120 000 and rounding to the nearest 100.
    private func predictStats(from imageEmbedding: [Float]) -> [Stat] {
        guard !statsEmbeddings.isEmpty else { return [] }

        var rankedStats = statsEmbeddings
            .map { ($0.label, dot(imageEmbedding, $0.embedding)) }
            .sorted(by: { $0.1 > $1.1 })

        // Pull the three core stats out of the ranking. When a label is missing the
        // best remaining stat is taken instead (as before); when nothing is left the
        // loop stops rather than removing from an empty array.
        var coreStats: [(String, Float)] = []
        for label in ["Attack", "Defence", "Magic"] {
            guard !rankedStats.isEmpty else { break }
            let index = rankedStats.firstIndex(where: { $0.0 == label }) ?? 0
            coreStats.append(rankedStats.remove(at: index))
        }

        let topThreeStats = rankedStats.prefix(3).map { ($0.0, $0.1 * 0.66) }
        let stats = coreStats + topThreeStats

        return stats.map { label, confidence in
            let roundedValue = (Double(abs(confidence)) * 120_000).rounded(to: 100)
            return Stat(name: label, value: roundedValue)
        }
    }

    /// Returns the best-matching title with every "X" replaced by `name`,
    /// or an empty string when no title embeddings are loaded.
    private func predictTitle(from imageEmbedding: [Float], _ name: String) -> String {
        guard !titleEmbeddings.isEmpty else { return "" }

        // Score first, then pick the maximum; only the winner needs the name substituted.
        let best = titleEmbeddings
            .map { (title: $0.fullNameCopy, score: dot(imageEmbedding, $0.embedding)) }
            .max(by: { $0.score < $1.score })

        return best?.title.replacingOccurrences(of: "X", with: name) ?? ""
    }

    /// Picks an elemental type for the image.
    ///
    /// One call in four returns a uniformly random type; otherwise the best-scoring
    /// type wins, with "Arcane" handicapped by 0.45. Returns an empty type when no
    /// type embeddings are loaded.
    private func predictType(from imageEmbedding: [Float]) -> ElementalType {
        guard !typeEmbeddings.isEmpty else { return .init(emoji: "", hexClor: "") }

        // 1 in 4 chance to return a random type
        if Int.random(in: 0..<4) == 0, let randomType = typeEmbeddings.randomElement() {
            return ElementalType(emoji: randomType.emoji, hexClor: randomType.hex)
        }

        let scoredTypes = typeEmbeddings.map { type -> (TypeEmbedding, Float) in
            var score = dot(imageEmbedding, type.embedding)
            if type.name == "Arcane" {
                score -= 0.45
            }
            return (type, score)
        }

        guard let topTypeItem = scoredTypes.max(by: { $0.1 < $1.1 })?.0 else {
            return .init(emoji: "", hexClor: "")
        }
        return ElementalType(emoji: topTypeItem.emoji, hexClor: topTypeItem.hex)
    }

    /// Dot product of two vectors. If the lengths differ, the extra components of the
    /// longer vector are ignored (`zip` stops at the shorter one).
    private func dot(_ a: [Float], _ b: [Float]) -> Float {
        zip(a, b).reduce(0) { $0 + $1.0 * $1.1 }
    }

    /// Derives a rarity from the last digit of the first embedding component when
    /// formatted with 12 decimal places: 9 is secret rare, 7-8 ultra rare, 4-6 rare,
    /// anything else (including an empty embedding or a non-digit) common.
    private func checkRarity(from imageEmbedding: [Float]) -> Rarity {
        guard let first = imageEmbedding.first else {
            return .common
        }

        let decimalString = String(format: "%.12f", first)
        guard let lastDigitChar = decimalString.last,
              let digit = Int(String(lastDigitChar)) else {
            return .common
        }

        switch digit {
        case 9:
            return .secretRare
        case 7...8:
            return .ultraRare
        case 4...6:
            return .rare
        default:
            return .common
        }
    }
}
private extension Double {
    /// Rounds the receiver to the closest multiple of `nearest`.
    ///
    /// - Parameter nearest: Step size to round to.
    /// - Returns: The multiple of `nearest` closest to the receiver.
    func rounded(to nearest: Double) -> Double {
        let wholeSteps = (self / nearest).rounded()
        return wholeSteps * nearest
    }
}
private extension CGImagePropertyOrientation {
    /// Creates the Core Graphics orientation that corresponds to a UIKit image orientation.
    ///
    /// The mapping is done case by case because the two enums do not share raw values.
    /// Orientations unknown to this code fall back to `.up`.
    ///
    /// - Parameter orientation: A `UIImage.Orientation` instance.
    init(_ orientation: UIImage.Orientation) {
        let converted: CGImagePropertyOrientation
        switch orientation {
        case .up:
            converted = .up
        case .down:
            converted = .down
        case .left:
            converted = .left
        case .right:
            converted = .right
        case .upMirrored:
            converted = .upMirrored
        case .downMirrored:
            converted = .downMirrored
        case .leftMirrored:
            converted = .leftMirrored
        case .rightMirrored:
            converted = .rightMirrored
        @unknown default:
            converted = .up
        }
        self = converted
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment