Skip to content

Commit

Permalink
Merge pull request #19 from saramaxyz/release/0.0.3
Browse files Browse the repository at this point in the history
Release/0.0.3
  • Loading branch information
yllfejziu authored Oct 4, 2023
2 parents 67bfa62 + 0fc085a commit a0f50ff
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"pins" : [
{
"identity" : "zipfoundation",
"kind" : "remoteSourceControl",
"location" : "https://github.com/weichsel/ZIPFoundation.git",
"state" : {
"revision" : "a3f5c2bae0f04b0bce9ef3c4ba6bd1031a0564c4",
"version" : "0.9.17"
}
}
],
"version" : 2
}
6 changes: 5 additions & 1 deletion ExampleApp/ExampleApp/MLModelInfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
//

import Foundation
import AeroEdge

struct MLModelInfo: Equatable {
let name: String
let modelType: ModelType
let bundledURL: URL?

static let yolo = MLModelInfo(name: "Yolo", bundledURL: Bundle.main.url(forResource: "YOLOv3Int8LUT", withExtension: ".mlmodelc"))
static let yolo = MLModelInfo(name: "Yolo",
modelType: .mlModel,
bundledURL: Bundle.main.url(forResource: "YOLOv3Int8LUT", withExtension: ".mlmodelc"))
}
9 changes: 7 additions & 2 deletions ExampleApp/ExampleApp/ViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ViewModel: ObservableObject {

func getYoloModel() async {
await aeroEdge.getModel(modelName: MLModelInfo.yolo.name,
modelType: MLModelInfo.yolo.modelType,
bundledModelURL: MLModelInfo.yolo.bundledURL) { progress in
print("Yolo Progress: \(progress)")
self.downloadProgress = progress
Expand All @@ -31,10 +32,14 @@ class ViewModel: ObservableObject {
switch result {
case .success(let model):
print(model.modelDescription)
self.modelDescription = model.description
DispatchQueue.main.async {
self.modelDescription = model.description
}
case .failure(let error):
print(error.localizedDescription)
self.modelDescription = error.localizedDescription
DispatchQueue.main.async {
self.modelDescription = error.localizedDescription
}
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"pins" : [
{
"identity" : "zipfoundation",
"kind" : "remoteSourceControl",
"location" : "https://github.com/weichsel/ZIPFoundation.git",
"state" : {
"revision" : "a3f5c2bae0f04b0bce9ef3c4ba6bd1031a0564c4",
"version" : "0.9.17"
}
}
],
"version" : 2
}
10 changes: 8 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// swift-tools-version: 5.9
// swift-tools-version: 5.8
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription
Expand All @@ -12,11 +12,17 @@ let package = Package(
name: "AeroEdge",
targets: ["AeroEdge"]),
],
dependencies: [
//.package(url: "https://github.com/marmelroy/Zip.git", .upToNextMinor(from: "2.1.0")),
.package(url: "https://github.com/weichsel/ZIPFoundation.git", .upToNextMajor(from: "0.9.0"))
],
targets: [
// Targets are the basic building blocks of a package, defining a module or a test suite.
// Targets can depend on other targets in this package and products from dependencies.
.target(
name: "AeroEdge"),
name: "AeroEdge",
dependencies: ["ZIPFoundation"]
),
.testTarget(
name: "AeroEdgeTests",
dependencies: ["AeroEdge"]),
Expand Down
56 changes: 43 additions & 13 deletions Sources/AeroEdge/AeroEdge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import CoreML
import ZIPFoundation

public class AeroEdge: NSObject {
public static let backgroundIdentifier = "com.app.backgroundModelDownload"
Expand All @@ -31,6 +32,7 @@ public class AeroEdge: NSObject {

public func getModel(
modelName: String,
modelType: ModelType = .mlModel,
bundledModelURL: URL?,
progress: ((Float) -> Void)?,
completion: @escaping (Result<MLModel, Error>, Bool) -> Void
Expand All @@ -41,10 +43,13 @@ public class AeroEdge: NSObject {
let remoteVersion = modelInfo.version

// Step 2: Check local model version
if modelChecker.checkLocalModelVersion(modelName: modelName, remoteVersion: remoteVersion),
let localVersion = self.localModelStore.getLocalModelVersion(for: modelName) {
if modelChecker.checkLocalModelVersion(modelName: modelName,
remoteVersion: remoteVersion,
fileExtension: modelType.rawValue),
let localVersion = self.localModelStore.getLocalModelVersion(for: modelName,
fileExtension: modelType.rawValue) {
// Load local model and return
let localModel = try await self.loadLocalModel(modelName: modelName, version: localVersion)
let localModel = try await self.loadLocalModel(modelName: modelName, version: localVersion, fileExtension: modelType.rawValue)
completion(.success(localModel), true) // true indicates that this is the final model and no newer version is available
} else {
// Step 3: If there's a local version, return it first
Expand All @@ -54,10 +59,11 @@ public class AeroEdge: NSObject {

// Now download the newer version from the server
if let progressClosure = progress {
self.downloadProgressClosures["\(modelName)_\(remoteVersion).mlmodel"] = progressClosure
self.downloadProgressClosures["\(modelName)_\(remoteVersion).\(modelType.rawValue)"] = progressClosure
}
do {
let newModel = try await self.downloadAndLoadModel(modelName: modelName,
fileExtension: modelType.rawValue,
remoteVersion: remoteVersion,
remoteModelURL: modelInfo.url,
bundledModelURL: bundledModelURL)
Expand All @@ -69,7 +75,9 @@ public class AeroEdge: NSObject {
} catch {
// Handle error - try loading local or bundled version if available
do {
let fallbackModel = try await loadLocalOrBundledModel(modelName: modelName, bundledModelURL: bundledModelURL)
let fallbackModel = try await loadLocalOrBundledModel(modelName: modelName,
bundledModelURL: bundledModelURL,
fileExtension: modelType.rawValue)
completion(.success(fallbackModel), true) // true indicates that this is the final model
} catch {
completion(.failure(error), true) // true indicates that this is the final callback
Expand All @@ -79,15 +87,20 @@ public class AeroEdge: NSObject {
}

private extension AeroEdge {
func loadLocalModel(modelName: String, version: Int) async throws -> MLModel {
func loadLocalModel(modelName: String, version: Int, fileExtension: String = "mlmodel") async throws -> MLModel {
// Get the URL of the local model using the `ModelLocalStore` instance
guard let modelURL = localModelStore.getLocalModelURL(for: modelName, version: version) else {
guard let modelURL = localModelStore.getLocalModelURL(for: modelName,
version: version,
fileExtension: fileExtension) else {
throw ModelError.modelNotFound
}

// Try to load the MLModel from the obtained URL
do {
let modelEntity = ModelEntity(name: modelName, version: version, url: modelURL)
let modelEntity = ModelEntity(name: modelName,
version: version,
url: modelURL,
fileExtension: fileExtension)
let model = try await modelCompiler.compileModel(model: modelEntity, from: modelURL)
return model
} catch {
Expand All @@ -96,16 +109,30 @@ private extension AeroEdge {
}
}

func downloadAndLoadModel(modelName: String, remoteVersion: Int, remoteModelURL: URL, bundledModelURL: URL?) async throws -> MLModel {
func downloadAndLoadModel(modelName: String,
fileExtension: String = "mlmodel",
remoteVersion: Int,
remoteModelURL: URL,
bundledModelURL: URL?) async throws -> MLModel {
do {
// Create a ModelEntity instance to represent the model
let modelEntity = ModelEntity(name: modelName, version: remoteVersion, url: remoteModelURL)
let modelEntity = ModelEntity(name: modelName,
version: remoteVersion,
url: remoteModelURL,
fileExtension: fileExtension)

// Download the model using the ModelDownloader
let downloadURL = try await modelDownloader.downloadModelAsync(modelEntity)

// TODO: HANDLE UNZIP HERE FIRST
try FileManager.default.unzipItem(at: downloadURL, to: downloadURL.deletingLastPathComponent())

// TODO: Delete zip file
try FileManager.default.removeItem(at: downloadURL)

// Load the newly downloaded model into memory
let mlModel = try await modelCompiler.compileModel(model: modelEntity, from: downloadURL)
let mlModel = try await modelCompiler.compileModel(model: modelEntity,
from: downloadURL.deletingPathExtension())

return mlModel
} catch {
Expand All @@ -120,8 +147,11 @@ private extension AeroEdge {
}
}

func loadLocalOrBundledModel(modelName: String, bundledModelURL: URL?) async throws -> MLModel {
if let latestLocalVersion = self.localModelStore.getLocalModelVersion(for: modelName) {
func loadLocalOrBundledModel(modelName: String,
bundledModelURL: URL?,
fileExtension: String = "mlmodel") async throws -> MLModel {
if let latestLocalVersion = self.localModelStore.getLocalModelVersion(for: modelName,
fileExtension: fileExtension) {
// If a local version is available, load it
return try await self.loadLocalModel(modelName: modelName, version: latestLocalVersion)
} else if let bundledModelURL = bundledModelURL {
Expand Down
9 changes: 9 additions & 0 deletions Sources/AeroEdge/Domain/Entities/ModelEntity.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@ public struct ModelEntity {
public let name: String
public let version: Int
public let url: URL
public var fileExtension: String = "mlmodel"

public var versionedName: String {
"\(name)_\(version)"
}

public var versionedNameWithExtension: String {
"\(versionedName).\(fileExtension)"
}

public var versionedNameWithExtensionZipped: String {
"\(versionedName).\(fileExtension).zip"
}
}
13 changes: 13 additions & 0 deletions Sources/AeroEdge/Domain/Entities/ModelType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//
// ModelType.swift
//
//
// Created by Yll Fejziu on 04/10/2023.
//

import Foundation

public enum ModelType: String, Equatable {
case mlPackage = "mlpackage"
case mlModel = "mlmodel"
}
4 changes: 3 additions & 1 deletion Sources/AeroEdge/Domain/UseCases/ModelCheckerUseCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
import Foundation

protocol ModelCheckerUseCase {
func checkLocalModelVersion(modelName: String, remoteVersion: Int) -> Bool
func checkLocalModelVersion(modelName: String,
remoteVersion: Int,
fileExtension: String) -> Bool
}
4 changes: 2 additions & 2 deletions Sources/AeroEdge/Domain/UseCases/ModelStorable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import Foundation

public protocol ModelStorable {
func getLocalModelVersion(for modelName: String) -> Int?
func getLocalModelURL(for modelName: String, version: Int) -> URL?
func getLocalModelVersion(for modelName: String, fileExtension: String) -> Int?
func getLocalModelURL(for modelName: String, version: Int, fileExtension: String) -> URL?
func saveLocalModel(_ model: ModelEntity, url: URL)
func deleteOldVersions(of model: ModelEntity)
}
10 changes: 8 additions & 2 deletions Sources/AeroEdge/Infrastructure/AeroEdgeModelServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@ public class AeroEdgeModelServer: ModelServer {
guard let apiResponse = try? JSONDecoder().decode(ModelInfo.self, from: data) else {
throw ModelError.failedToLoadModel("Failed to parse response")
}

return ModelEntity(name: apiResponse.name, version: apiResponse.version, url: apiResponse.signed_url)
return ModelEntity(name: apiResponse.name,
version: apiResponse.version,
url: apiResponse.signed_url,
fileExtension: apiResponse.fileExtension)
}

struct ModelInfo: Decodable {
let name: String
let signed_url: URL
let version: Int

var fileExtension: String {
signed_url.lastPathComponent.split(separator: ".").dropFirst().joined(separator: ".")
}
}
}
6 changes: 4 additions & 2 deletions Sources/AeroEdge/Infrastructure/ModelChecker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class ModelChecker: ModelCheckerUseCase {
self.localModelStore = localModelStore
}

func checkLocalModelVersion(modelName: String, remoteVersion: Int) -> Bool {
func checkLocalModelVersion(modelName: String,
remoteVersion: Int,
fileExtension: String = "mlmodel") -> Bool {
// Get the version of the local model
if let localVersion = localModelStore.getLocalModelVersion(for: modelName) {
if let localVersion = localModelStore.getLocalModelVersion(for: modelName, fileExtension: fileExtension) {
return localVersion >= remoteVersion
}

Expand Down
2 changes: 1 addition & 1 deletion Sources/AeroEdge/Infrastructure/ModelCompiler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModelCompiler: ModelCompilerUseCase {
in: .userDomainMask,
appropriateFor: nil,
create: true)
let compiledModelURL = applicationSupportDirectoryURL.appendingPathComponent("\(model.versionedName).mlmodelc")
let compiledModelURL = applicationSupportDirectoryURL.appendingPathComponent("\(model.versionedNameWithExtension)")

// Check if the compiled model already exists, if yes, then return it
if fileManager.fileExists(atPath: compiledModelURL.path) {
Expand Down
6 changes: 4 additions & 2 deletions Sources/AeroEdge/Infrastructure/ModelDownloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ extension ModelDownloader: URLSessionDownloadDelegate {
}

do {
let destinationURL = modelStore.getLocalModelURL(for: model.name, version: model.version) ?? createDestinationURL(for: model)
let destinationURL = modelStore.getLocalModelURL(for: model.name,
version: model.version,
fileExtension: model.fileExtension) ?? createDestinationURL(for: model)

// Check if directory exists, if not create it
let directoryURL = destinationURL.deletingLastPathComponent()
Expand Down Expand Up @@ -111,7 +113,7 @@ private extension ModelDownloader {

func createDestinationURL(for model: ModelEntity) -> URL {
let documentsDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
return documentsDirectory.appendingPathComponent(model.versionedName + ".mlmodel")
return documentsDirectory.appendingPathComponent(model.versionedNameWithExtensionZipped)
}

func deleteOldVersions(of model: ModelEntity) {
Expand Down
11 changes: 6 additions & 5 deletions Sources/AeroEdge/Infrastructure/ModelLocalStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public class ModelLocalStore: ModelStorable {
self.fileManager = fileManager
}

public func getLocalModelVersion(for modelName: String) -> Int? {
public func getLocalModelVersion(for modelName: String,
fileExtension: String = "mlmodel") -> Int? {
let documentsDirectory = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first!

do {
Expand All @@ -27,7 +28,7 @@ public class ModelLocalStore: ModelStorable {
let versions = modelFiles.compactMap { url -> Int? in
let fileName = url.lastPathComponent
guard let startRange = fileName.range(of: "\(modelName)_"),
let endRange = fileName.range(of: ".mlmodel") else { return nil }
let endRange = fileName.range(of: ".\(fileExtension)") else { return nil }

// Extract the version number using the range between the modelName_ and .mlmodel
let versionString = fileName[startRange.upperBound..<endRange.lowerBound]
Expand All @@ -43,8 +44,8 @@ public class ModelLocalStore: ModelStorable {
}
}

public func getLocalModelURL(for modelName: String, version: Int) -> URL? {
let modelNameWithVersion = "\(modelName)_\(version).mlmodel"
public func getLocalModelURL(for modelName: String, version: Int, fileExtension: String) -> URL? {
let modelNameWithVersion = "\(modelName)_\(version).\(fileExtension)"
let documentsDirectory = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first!
let fileURL = documentsDirectory.appendingPathComponent(modelNameWithVersion)

Expand All @@ -53,7 +54,7 @@ public class ModelLocalStore: ModelStorable {

public func saveLocalModel(_ model: ModelEntity, url: URL) {
let documentsDirectory = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first!
let destinationURL = documentsDirectory.appendingPathComponent(model.versionedName + ".mlmodel")
let destinationURL = documentsDirectory.appendingPathComponent(model.versionedNameWithExtensionZipped)

do {
if fileManager.fileExists(atPath: destinationURL.path) {
Expand Down

0 comments on commit a0f50ff

Please sign in to comment.