Created
June 4, 2025 06:23
-
-
Save jacobsapps/086d5733e35227e36b3088279c05fe69 to your computer and use it in GitHub Desktop.
CoreMLProcessing.swift
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Vision | |
import UIKit | |
import ImageIO | |
/// One entry of the bundled `titles_embeddings.json` file.
///
/// Pairs a title template with the embedding vector it is scored by.
struct TitleEmbedding: Codable {
    /// Title template; every "X" in it is replaced with the card's name when a title is chosen.
    let fullNameCopy: String

    /// Keyword stored alongside the template (not read anywhere in this file).
    let keyword: String

    /// Embedding vector compared against the image embedding via a dot product.
    let embedding: [Float]
}
/// One entry of the bundled `stats_embeddings.json` file.
///
/// Pairs a stat label with the embedding vector it is scored by.
struct StatEmbedding: Codable {
    /// Display name of the stat (for example "Attack", "Defence" or "Magic").
    let label: String

    /// Embedding vector compared against the image embedding via a dot product.
    let embedding: [Float]
}
/// One entry of the bundled `type_embeddings.json` file.
///
/// Describes an elemental type together with the embedding vector it is scored by.
struct TypeEmbedding: Codable, Identifiable {
    /// Stable identity for `Identifiable`; derived from `name`, so it is not part of the JSON.
    var id: String {
        return name
    }

    let type: String

    /// Name of the type; also serves as its identity.
    let name: String

    /// Emoji shown for the type.
    let emoji: String

    let color: String

    /// Hex colour string handed to `ElementalType` when this type is selected.
    let hex: String

    /// Embedding vector compared against the image embedding via a dot product.
    let embedding: [Float]
}
/// Turns a photo into `CardStats` by running a bundled Core ML image encoder through
/// Vision and comparing the resulting image embedding with precomputed embeddings for
/// titles, stats and elemental types.
///
/// Call `loadCLIPModelAndEmbeddings()` before `makePredictions(for:name:)`; until the
/// model is loaded, predictions return `nil`.
final class ImagePredictor {
    private var model: VNCoreMLModel? = nil
    private var statsEmbeddings: [StatEmbedding] = []
    private var titleEmbeddings: [TitleEmbedding] = []
    private var typeEmbeddings: [TypeEmbedding] = []

    // MARK: - Loading

    /// Loads the `mobileclip_blt_image` model and the three embedding JSON files from the main bundle.
    ///
    /// Any previously loaded state is discarded first. A missing bundle resource or a model
    /// that cannot be instantiated is treated as a programmer error and traps. An embedding
    /// file that exists but cannot be read or decoded is logged and leaves only that
    /// collection empty; the other files are still loaded.
    func loadCLIPModelAndEmbeddings() {
        model = nil
        statsEmbeddings = []
        titleEmbeddings = []
        typeEmbeddings = []

        let defaultConfig = MLModelConfiguration()
        guard let modelURL = Bundle.main.url(forResource: "mobileclip_blt_image", withExtension: "mlmodelc") else {
            fatalError("Model file not found in the bundle.")
        }

        let imageClassifierModel: MLModel
        do {
            imageClassifierModel = try MLModel(contentsOf: modelURL, configuration: defaultConfig)
        } catch {
            fatalError("Failed to load the model from file: \(error.localizedDescription)")
        }

        guard let imageClassifierVisionModel = try? VNCoreMLModel(for: imageClassifierModel) else {
            fatalError("App failed to create a `VNCoreMLModel` instance.")
        }
        model = imageClassifierVisionModel

        // Each file is loaded independently so one corrupt file cannot prevent the others from loading.
        titleEmbeddings = loadEmbeddings(named: "titles_embeddings")
        statsEmbeddings = loadEmbeddings(named: "stats_embeddings")
        typeEmbeddings = loadEmbeddings(named: "type_embeddings")
    }

    /// Decodes a JSON array of embeddings from the main bundle.
    ///
    /// - Parameter resource: Resource name without the `.json` extension.
    /// - Returns: The decoded entries, or an empty array when reading or decoding fails.
    private func loadEmbeddings<T: Decodable>(named resource: String) -> [T] {
        guard let url = Bundle.main.url(forResource: resource, withExtension: "json") else {
            fatalError("Embedding file \(resource).json not found.")
        }
        do {
            let data = try Data(contentsOf: url)
            return try JSONDecoder().decode([T].self, from: data)
        } catch {
            // Name the failing file so a bad resource is easy to track down.
            print("Failed to load \(resource).json: \(error)")
            return []
        }
    }

    // MARK: - Prediction

    /// Runs the image encoder on `photo` and derives card attributes from the embedding.
    ///
    /// - Parameters:
    ///   - photo: The image to analyse; it must be backed by a `CGImage`.
    ///   - name: Name substituted into the chosen title template.
    /// - Returns: The predicted card stats, or `nil` when no model has been loaded.
    /// - Throws: An `NSError` in the `ImagePredictor` domain (code 1: the photo has no
    ///   `CGImage`; code 2: the request produced no usable embedding), or whatever error
    ///   Vision throws while performing the request.
    func makePredictions(for photo: UIImage, name: String) async throws -> CardStats? {
        guard let model else { return nil }
        let orientation = CGImagePropertyOrientation(photo.imageOrientation)
        guard let photoImage = photo.cgImage else {
            throw NSError(domain: "ImagePredictor", code: 1, userInfo: [NSLocalizedDescriptionKey: "Photo doesn't have underlying CGImage."])
        }

        // `perform` is synchronous, so the results are read after it returns instead of
        // bridging through a continuation. With a completion handler plus a `catch`, a
        // failing request could report the same error on both paths and resume the
        // continuation twice.
        let request = VNCoreMLRequest(model: model)
        request.imageCropAndScaleOption = .centerCrop
        let handler = VNImageRequestHandler(cgImage: photoImage, orientation: orientation)
        try handler.perform([request])

        guard let cardStats = handleCLIPPrediction(from: request, name: name) else {
            let description = "Unexpected result type: \(String(describing: request.results))"
            throw NSError(domain: "ImagePredictor", code: 2, userInfo: [NSLocalizedDescriptionKey: description])
        }
        return cardStats
    }

    /// Extracts the image embedding from a finished request, L2-normalises it and derives the card attributes.
    ///
    /// - Returns: `nil` when the first result is not a multi-array feature value, or when
    ///   the embedding cannot be normalised (zero or non-finite norm).
    private func handleCLIPPrediction(from request: VNRequest, name: String) -> CardStats? {
        guard let result = request.results?.first as? VNCoreMLFeatureValueObservation,
              let imageEmbedding = result.featureValue.multiArrayValue else {
            return nil
        }

        let floatArray: [Float] = (0..<imageEmbedding.count).map {
            Float(truncating: imageEmbedding[$0])
        }
        let norm = sqrt(floatArray.reduce(0) { $0 + $1 * $1 })
        // Dividing by a zero or non-finite norm would turn every score into NaN.
        guard norm > 0, norm.isFinite else { return nil }
        let normalized = floatArray.map { $0 / norm }

        let stats = predictStats(from: normalized)
        let title = predictTitle(from: normalized, name)
        let type = predictType(from: normalized)
        let rarity = checkRarity(from: normalized)
        return CardStats(title: title, type: type, stats: stats, rarity: rarity)
    }

    /// Builds the stat list: "Attack", "Defence" and "Magic" first, followed by the three
    /// best-scoring remaining stats with their scores scaled by 0.66.
    ///
    /// When one of the three named stats is absent, the best-scoring remaining stat takes
    /// its slot; when no stats remain, the slot is skipped rather than trapping.
    private func predictStats(from imageEmbedding: [Float]) -> [Stat] {
        guard !statsEmbeddings.isEmpty else { return [] }

        var rankedStats = statsEmbeddings
            .map { ($0.label, dot(imageEmbedding, $0.embedding)) }
            .sorted(by: { $0.1 > $1.1 })

        let coreStats = ["Attack", "Defence", "Magic"].compactMap { label -> (String, Float)? in
            // Guard the removal: fewer than three loaded stats would otherwise index an empty array.
            guard !rankedStats.isEmpty else { return nil }
            let index = rankedStats.firstIndex(where: { $0.0 == label }) ?? 0
            return rankedStats.remove(at: index)
        }
        let topThreeStats = rankedStats.prefix(3).map { ($0.0, $0.1 * 0.66) }

        return (coreStats + topThreeStats).map { label, confidence in
            let roundedValue = (Double(abs(confidence)) * 120_000).rounded(to: 100)
            return Stat(name: label, value: roundedValue)
        }
    }

    /// Picks the best-scoring title template and substitutes `name` for every "X" in it.
    ///
    /// - Returns: The personalised title, or an empty string when no titles are loaded.
    private func predictTitle(from imageEmbedding: [Float], _ name: String) -> String {
        let scoredTitles = titleEmbeddings.map {
            (template: $0.fullNameCopy, score: dot(imageEmbedding, $0.embedding))
        }
        // `max(by:)` keeps the earliest entry on ties, like taking the head of a descending sort.
        guard let best = scoredTitles.max(by: { $0.score < $1.score }) else { return "" }
        return best.template.replacingOccurrences(of: "X", with: name)
    }

    /// Chooses an elemental type for the image.
    ///
    /// One call in four returns a uniformly random type. Otherwise the best-scoring type
    /// wins, with the "Arcane" score reduced by 0.45 before comparison.
    ///
    /// - Returns: The chosen type, or a type with empty emoji and colour when none are loaded.
    private func predictType(from imageEmbedding: [Float]) -> ElementalType {
        guard !typeEmbeddings.isEmpty else { return .init(emoji: "", hexClor: "") }

        // 1 in 4 chance to return a random type
        if Int.random(in: 0..<4) == 0, let randomType = typeEmbeddings.randomElement() {
            return ElementalType(emoji: randomType.emoji, hexClor: randomType.hex)
        }

        let scoredTypes = typeEmbeddings.map { type -> (TypeEmbedding, Float) in
            var score = dot(imageEmbedding, type.embedding)
            if type.name == "Arcane" {
                score -= 0.45
            }
            return (type, score)
        }
        guard let topTypeItem = scoredTypes.max(by: { $0.1 < $1.1 })?.0 else {
            return .init(emoji: "", hexClor: "")
        }
        return ElementalType(emoji: topTypeItem.emoji, hexClor: topTypeItem.hex)
    }

    /// Dot product of two vectors; if the lengths differ, the extra elements of the longer one are ignored.
    private func dot(_ a: [Float], _ b: [Float]) -> Float {
        zip(a, b).reduce(0) { $0 + $1.0 * $1.1 }
    }

    /// Derives a rarity from the last digit of the embedding's first component when it is
    /// formatted with twelve decimal places: 9 is secret rare, 7-8 ultra rare, 4-6 rare,
    /// anything else common.
    private func checkRarity(from imageEmbedding: [Float]) -> Rarity {
        guard let first = imageEmbedding.first else {
            return .common
        }
        let decimalString = String(format: "%.12f", first)
        guard let lastDigitChar = decimalString.last,
              let digit = Int(String(lastDigitChar)) else {
            return .common
        }
        switch digit {
        case 9:
            return .secretRare
        case 7...8:
            return .ultraRare
        case 4...6:
            return .rare
        default:
            return .common
        }
    }
}
private extension Double {
    /// Rounds the receiver to the closest multiple of `nearest`.
    ///
    /// - Parameter nearest: The step size to round to.
    /// - Returns: The multiple of `nearest` closest to the receiver, using the default
    ///   rule of `rounded()` for values exactly halfway between two multiples.
    func rounded(to nearest: Double) -> Double {
        let wholeSteps = (self / nearest).rounded()
        return wholeSteps * nearest
    }
}
private extension CGImagePropertyOrientation {
    /// Creates the Core Graphics orientation that corresponds to a UIKit image orientation.
    ///
    /// The two enumerations use different raw values, so each case is mapped by name
    /// rather than by raw value. Cases added to `UIImage.Orientation` in the future
    /// fall back to `.up`.
    ///
    /// - Parameter orientation: A `UIImage.Orientation` instance.
    init(_ orientation: UIImage.Orientation) {
        switch orientation {
        case .up:
            self = .up
        case .down:
            self = .down
        case .left:
            self = .left
        case .right:
            self = .right
        case .upMirrored:
            self = .upMirrored
        case .downMirrored:
            self = .downMirrored
        case .leftMirrored:
            self = .leftMirrored
        case .rightMirrored:
            self = .rightMirrored
        @unknown default:
            self = .up
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment