Merge pull request #3744 from Yogayu/master

New Version of iOS MNN Chat
commit e814142254 by jxt1234, 2025-07-22 14:17:19 +08:00, committed by GitHub
110 changed files with 10801 additions and 1224 deletions

.gitignore (vendored)

@@ -376,3 +376,5 @@ datasets/*
# qnn 3rdParty
source/backend/qnn/3rdParty/include
apps/iOS/MNNLLMChat/Chat
apps/iOS/MNNLLMChat/swift-transformers


@@ -52,7 +52,7 @@
isa = PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet;
buildPhase = 3E8591FA2D1D45070067B46F /* Sources */;
membershipExceptions = (
LLMWrapper/DiffusionSession.h,
InferenceEngine/DiffusionSession.h,
);
};
/* End PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */


@@ -0,0 +1,21 @@
{
"images" : [
{
"filename" : "star.png",
"idiom" : "universal",
"scale" : "1x"
},
{
"idiom" : "universal",
"scale" : "2x"
},
{
"idiom" : "universal",
"scale" : "3x"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

[Binary image file added: 3.6 KiB, not shown]


@@ -63,7 +63,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
}
Task {
var status: Message.Status = .sending
let status: Message.Status = .sending
var sender: LLMChatUser
switch userType {
@@ -74,13 +74,15 @@ final class LLMChatInteractor: ChatInteractorProtocol {
case .system:
sender = chatData.system
}
var message: LLMChatMessage = await draftMessage.toLLMChatMessage(
let message: LLMChatMessage = await draftMessage.toLLMChatMessage(
id: UUID().uuidString,
user: sender,
status: status)
DispatchQueue.main.async { [weak self] in
// PerformanceMonitor.shared.recordUIUpdate()
switch userType {
case .user, .system:
self?.chatState.value.append(message)
@@ -97,18 +99,24 @@ final class LLMChatInteractor: ChatInteractorProtocol {
case .assistant:
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
if let isDeepSeek = self?.modelInfo.name.lowercased().contains("deepseek"), isDeepSeek == true,
let text = self?.processor.process(progress: message.text) {
updateLastMsg?.text = text
} else {
updateLastMsg?.text += message.text
}
message.text = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1].text ?? ""
self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] = updateLastMsg ?? message
// PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
let text = self?.processor.process(progress: message.text) {
updateLastMsg?.text = text
} else {
if let currentText = updateLastMsg?.text {
updateLastMsg?.text = currentText + message.text
} else {
updateLastMsg?.text = message.text
}
}
if let updatedMsg = updateLastMsg {
self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] = updatedMsg
}
// }
}
}
}


@@ -25,13 +25,13 @@ final class LLMChatData {
self.assistant = LLMChatUser(
uid: "2",
name: modelInfo.name,
name: modelInfo.modelName,
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
)
self.system = LLMChatUser(
uid: "0",
name: modelInfo.name,
name: modelInfo.modelName,
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
)
}


@@ -19,7 +19,7 @@ actor LLMState {
return isProcessing
}
func processContent(_ content: String, llm: LLMInferenceEngineWrapper?, completion: @escaping (String) -> Void) {
llm?.processInput(content, withOutput: completion)
func processContent(_ content: String, llm: LLMInferenceEngineWrapper?, showPerformance: Bool, completion: @escaping (String) -> Void) {
llm?.processInput(content, withOutput: completion, showPerformance: showPerformance)
}
}


@@ -0,0 +1,121 @@
//
// PerformanceMonitor.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import Foundation
import UIKit
/**
* PerformanceMonitor - A singleton utility for monitoring and measuring UI performance
*
* This class provides real-time performance monitoring capabilities to help identify
* UI update bottlenecks, frame drops, and slow operations in iOS applications.
* It's particularly useful during development to ensure smooth user experience.
*
* Key Features:
* - Real-time FPS monitoring and frame drop detection
* - UI update lag detection with customizable thresholds
* - Execution time measurement for specific operations
* - Automatic performance statistics reporting
* - Thread-safe singleton implementation
*
* Usage Examples:
*
* 1. Monitor UI Updates:
* ```swift
* // Call this in your UI update methods
* PerformanceMonitor.shared.recordUIUpdate()
* ```
*
* 2. Measure Operation Performance:
* ```swift
* let result = PerformanceMonitor.shared.measureExecutionTime(operation: "Data Processing") {
* // Your expensive operation here
* return processLargeDataSet()
* }
* ```
*
* 3. Integration in ViewModels:
* ```swift
* func updateUI() {
* PerformanceMonitor.shared.recordUIUpdate()
* // Your UI update code
* }
* ```
*
* Performance Thresholds:
* - Target FPS: 60 FPS
* - Frame threshold: 25ms (1.5x normal frame time)
* - Slow operation threshold: 16ms (1 frame time)
*/
class PerformanceMonitor {
static let shared = PerformanceMonitor()
private var lastUpdateTime: Date = Date() // Time of the previous recordUIUpdate() call
private var windowStartTime: Date = Date() // Start of the current 1-second reporting window
private var updateCount: Int = 0
private var frameDropCount: Int = 0
private let targetFPS: Double = 60.0
private let frameThreshold: TimeInterval = 1.0 / 60.0 * 1.5 // Allow 1.5x normal frame time
private init() {}
/**
* Records a UI update event and monitors performance metrics
*
* Call this method whenever you perform UI updates to track performance.
* It automatically detects frame drops and calculates FPS statistics.
* Performance statistics are logged every second.
*/
func recordUIUpdate() {
let currentTime = Date()
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
lastUpdateTime = currentTime
updateCount += 1
// Detect frame drops using the gap since the previous update
if timeDiff > frameThreshold {
frameDropCount += 1
print("⚠️ UI Update Lag detected: \(timeDiff * 1000)ms (expected: \(frameThreshold * 1000)ms)")
}
// Report statistics once the 1-second window has elapsed
let windowDiff = currentTime.timeIntervalSince(windowStartTime)
if windowDiff >= 1.0 {
let actualFPS = Double(updateCount) / windowDiff
let dropRate = Double(frameDropCount) / Double(updateCount) * 100
print("📊 Performance Stats - FPS: \(String(format: "%.1f", actualFPS)), Drop Rate: \(String(format: "%.1f", dropRate))%")
// Reset counters for next measurement cycle
updateCount = 0
frameDropCount = 0
windowStartTime = currentTime
}
}
/**
* Measures execution time for a specific operation
*
* Wraps any operation and measures its execution time. Operations taking
* longer than 16ms (1 frame time) are logged as slow operations.
*
* - Parameters:
* - operation: A descriptive name for the operation being measured
* - block: The operation to measure
* - Returns: The result of the operation
* - Throws: Re-throws any error thrown by the operation
*/
func measureExecutionTime<T>(operation: String, block: () throws -> T) rethrows -> T {
let startTime = CFAbsoluteTimeGetCurrent()
let result = try block()
let executionTime = CFAbsoluteTimeGetCurrent() - startTime
if executionTime > 0.016 { // Over 16ms (1 frame time)
print("⏱️ Slow Operation: \(operation) took \(String(format: "%.3f", executionTime * 1000))ms")
}
return result
}
}


@@ -0,0 +1,214 @@
# UI Performance Optimization Guide
## Overview
This document explains the performance optimization utilities implemented in the chat system to ensure smooth AI streaming text output and overall UI responsiveness.
## Core Components
### 1. PerformanceMonitor
A singleton utility for real-time performance monitoring and measurement.
#### Implementation Principles
- **Real-time FPS Monitoring**: Tracks UI update frequency and calculates actual FPS
- **Frame Drop Detection**: Identifies when the gap between UI updates exceeds the 25ms threshold (1.5× the 16.67ms frame time at 60 FPS)
- **Operation Time Measurement**: Measures execution time of specific operations
- **Automatic Statistics Reporting**: Logs performance metrics every second
#### Key Features
```swift
class PerformanceMonitor {
static let shared = PerformanceMonitor()
// Performance thresholds
private let targetFPS: Double = 60.0
private let frameThreshold: TimeInterval = 1.0 / 60.0 * 1.5 // 25ms threshold
func recordUIUpdate() { /* Track UI updates */ }
func measureExecutionTime<T>(operation: String, block: () throws -> T) rethrows -> T { /* Measure operations */ }
}
```
#### Usage in the Chat System
- Integrated into `LLMChatInteractor` to monitor message updates
- Tracks UI update frequency during AI text streaming
- Identifies performance bottlenecks in real-time
### 2. UIUpdateOptimizer
An actor-based utility for batching and throttling UI updates during streaming scenarios.
#### Implementation Principles
- **Batching Mechanism**: Groups multiple small updates into larger, more efficient ones
- **Time-based Throttling**: Limits update frequency to prevent UI overload
- **Actor-based Thread Safety**: Ensures safe concurrent access to update queue
- **Automatic Flush Strategy**: Intelligently decides when to apply batched updates
#### Architecture
```swift
actor UIUpdateOptimizer {
static let shared = UIUpdateOptimizer()
private var pendingUpdates: [String] = []
private let batchSize: Int = 5 // Batch threshold
private let flushInterval: TimeInterval = 0.03 // 30ms throttling
func addUpdate(_ content: String, completion: @escaping (String) -> Void)
func forceFlush(completion: @escaping (String) -> Void)
}
```
#### Optimization Strategies
1. **Batch Size Control**: Groups up to 5 updates before flushing
2. **Time-based Throttling**: Flushes updates every 30ms maximum
3. **Intelligent Scheduling**: Cancels redundant flush operations
4. **Main Thread Delegation**: Ensures UI updates occur on the main thread
#### Integration Points
- **LLM Streaming**: Optimizes real-time text output from AI models
- **Message Updates**: Batches frequent message content changes
- **Force Flush**: Ensures final content is displayed when streaming ends
## Performance Optimization Flow
```
AI Model Output → UIUpdateOptimizer → Batched Updates → UI Thread → Display
                          ↓
             PerformanceMonitor (Monitoring)
                          ↓
                Console Logs (Metrics)
```
## Testing and Validation
### Performance Metrics
1. **Target Performance**:
- Maintain 50+ FPS during streaming
- Keep frame drop rate below 5%
- Single operations under 16ms
2. **Monitoring Indicators**:
- `📊 Performance Stats` - Real-time FPS and drop rate
- `⚠️ UI Update Lag detected` - Frame drop warnings
- `⏱️ Slow Operation` - Operation time alerts
### Testing Methodology
1. **Streaming Tests**:
- Test with long-form AI responses (articles, code)
- Monitor console output for performance warnings
- Observe visual smoothness of text animation
2. **Load Testing**:
- Rapid successive message sending
- Large text blocks processing
- Multiple concurrent operations
3. **Comparative Analysis**:
- Before/after optimization measurements
- Different device performance profiles
- Various content types and sizes
### Debug Configuration
For development and testing purposes:
```swift
// Example configuration adjustments (not implemented in production)
// UIUpdateOptimizer.shared.batchSize = 10
// UIUpdateOptimizer.shared.flushInterval = 0.05
```
## Implementation Details
### UIUpdateOptimizer Algorithm
1. **Add Update**: New content is appended to pending queue
2. **Threshold Check**: Evaluate if immediate flush is needed
- Batch size reached (≥5 updates)
- Time threshold exceeded (≥30ms since last flush)
3. **Scheduling**: If not immediate, schedule delayed flush
4. **Flush Execution**: Combine all pending updates and execute on main thread
5. **Cleanup**: Clear queue and reset timing
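A minimal, self-contained sketch of the decision in steps 1-3; the thresholds mirror the documented configuration, but the function itself is illustrative and not part of the shipped optimizer:

```swift
import Foundation

// Returns true when pending updates should be flushed immediately:
// either the batch is full (>= 5 updates) or the throttle window
// (30ms) since the last flush has elapsed.
func shouldFlushNow(pendingCount: Int,
                    lastFlush: Date,
                    batchSize: Int = 5,
                    flushInterval: TimeInterval = 0.03) -> Bool {
    pendingCount >= batchSize ||
        Date().timeIntervalSince(lastFlush) >= flushInterval
}
```

When this check returns `false`, the optimizer instead schedules a delayed flush (step 3), cancelling any previously scheduled one.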
### PerformanceMonitor Algorithm
1. **Update Recording**: Track each UI update call
2. **Timing Analysis**: Calculate time difference between updates
3. **Frame Drop Detection**: Compare against 25ms threshold
4. **Statistics Calculation**: Compute FPS and drop rate every second
5. **Logging**: Output performance metrics to console
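The per-second report in steps 4-5 reduces to two ratios; a standalone sketch under the same assumptions (the counter names match the monitor's, and the 25ms check has already classified each update):

```swift
import Foundation

// FPS is updates per elapsed second; drop rate is the share of updates
// whose inter-update gap exceeded the 25ms frame threshold.
func performanceStats(updateCount: Int,
                      frameDropCount: Int,
                      elapsed: TimeInterval) -> (fps: Double, dropRatePercent: Double)? {
    guard elapsed > 0, updateCount > 0 else { return nil }
    return (fps: Double(updateCount) / elapsed,
            dropRatePercent: Double(frameDropCount) / Double(updateCount) * 100)
}
```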
## Integration Examples
### In ViewModels
```swift
func updateUI() {
PerformanceMonitor.shared.recordUIUpdate()
// UI update code here
}
let result = PerformanceMonitor.shared.measureExecutionTime(operation: "Data Processing") {
return processLargeDataSet()
}
```
### In Streaming Scenarios
```swift
await UIUpdateOptimizer.shared.addUpdate(newText) { batchedContent in
// Update UI with optimized batched content
updateTextView(with: batchedContent)
}
// When stream ends
await UIUpdateOptimizer.shared.forceFlush { finalContent in
finalizeTextDisplay(with: finalContent)
}
```
## Troubleshooting
### Common Performance Issues
1. **High Frame Drop Rate**:
- Check for blocking operations on main thread
- Verify batch size configuration
- Monitor memory usage
2. **Slow Operation Warnings**:
- Profile specific operations causing delays
- Consider background threading for heavy tasks
- Optimize data processing algorithms
3. **Inconsistent Performance**:
- Check device thermal state
- Monitor memory pressure
- Verify background app activity
### Diagnostic Tools
- **Console Monitoring**: Watch for performance log messages
- **Xcode Instruments**: Use Time Profiler for detailed analysis
- **Memory Graph**: Check for memory leaks affecting performance
- **Energy Impact**: Monitor battery and thermal effects
## Best Practices
1. **Proactive Monitoring**: Always call `recordUIUpdate()` for critical UI operations
2. **Batch When Possible**: Use `UIUpdateOptimizer` for frequent updates
3. **Measure Critical Paths**: Wrap expensive operations with `measureExecutionTime`
4. **Test on Real Devices**: Performance varies significantly across device types
5. **Monitor in Production**: Keep performance logging enabled during development
This performance optimization system ensures a smooth user experience during AI text generation while providing developers with the tools needed to maintain and improve performance over time.


@@ -0,0 +1,136 @@
//
// UIUpdateOptimizer.swift
// MNNLLMiOS
//
// Created by () on 2025/7/7.
//
import Foundation
import SwiftUI
/**
* UIUpdateOptimizer - A utility for batching and throttling UI updates to improve performance
*
* This actor-based optimizer helps reduce the frequency of UI updates by batching multiple
* updates together and applying throttling mechanisms. It's particularly useful for scenarios
* like streaming text updates, real-time data feeds, or any situation where frequent UI
* updates might cause performance issues.
*
* Key Features:
* - Batches multiple updates into a single operation
* - Applies time-based throttling to limit update frequency
* - Thread-safe actor implementation
* - Automatic flush mechanism for pending updates
*
* Usage Example:
* ```swift
* // For streaming text updates
* await UIUpdateOptimizer.shared.addUpdate(newText) { batchedContent in
* // Update UI with batched content
* textView.text = batchedContent
* }
*
* // Force flush remaining updates when stream ends
* await UIUpdateOptimizer.shared.forceFlush { finalContent in
* textView.text = finalContent
* }
* ```
*
* Configuration:
* - batchSize: Number of updates to batch before triggering immediate flush (default: 5)
* - flushInterval: Time interval in seconds between automatic flushes (default: 0.03s / 30ms)
*/
actor UIUpdateOptimizer {
static let shared = UIUpdateOptimizer()
private var pendingUpdates: [String] = []
private var lastFlushTime: Date = Date()
private var flushTask: Task<Void, Never>?
// Configuration constants
private let batchSize: Int = 5 // Batch size threshold for immediate flush
private let flushInterval: TimeInterval = 0.03 // 30ms throttling interval
private init() {}
/**
* Adds a content update to the pending queue
*
* Updates are either flushed immediately if batch size or time threshold is reached,
* or scheduled for delayed flushing to optimize performance.
*
* - Parameters:
* - content: The content string to add to the update queue
* - completion: Callback executed with the batched content when flushed
*/
func addUpdate(_ content: String, completion: @escaping (String) -> Void) {
pendingUpdates.append(content)
// Determine if immediate flush is needed based on batch size or time interval
let shouldFlushImmediately = pendingUpdates.count >= batchSize ||
Date().timeIntervalSince(lastFlushTime) >= flushInterval
if shouldFlushImmediately {
flushUpdates(completion: completion)
} else {
// Schedule delayed flush to optimize performance
scheduleFlush(completion: completion)
}
}
/**
* Schedules a delayed flush operation
*
* Cancels any existing scheduled flush and creates a new one to avoid
* excessive flush operations while maintaining responsiveness.
*
* - Parameter completion: Callback to execute when flush occurs
*/
private func scheduleFlush(completion: @escaping (String) -> Void) {
// Cancel previous scheduled flush to avoid redundant operations
flushTask?.cancel()
flushTask = Task {
try? await Task.sleep(nanoseconds: UInt64(flushInterval * 1_000_000_000))
if !Task.isCancelled && !pendingUpdates.isEmpty {
flushUpdates(completion: completion)
}
}
}
/**
* Flushes all pending updates immediately
*
* Combines all pending updates into a single string and executes the completion
* callback on the main actor thread for UI updates.
*
* - Parameter completion: Callback executed with the combined content
*/
private func flushUpdates(completion: @escaping (String) -> Void) {
guard !pendingUpdates.isEmpty else { return }
let batchedContent = pendingUpdates.joined()
pendingUpdates.removeAll()
lastFlushTime = Date()
Task { @MainActor in
completion(batchedContent)
}
}
/**
* Forces immediate flush of any remaining pending updates
*
* This method should be called when you need to ensure all pending updates
* are processed immediately, such as when a stream ends or the view is about
* to disappear.
*
* - Parameter completion: Callback executed with any remaining content
*/
func forceFlush(completion: @escaping (String) -> Void) {
if !pendingUpdates.isEmpty {
flushUpdates(completion: completion)
}
}
}


@@ -11,7 +11,6 @@ import AVFoundation
import ExyteChat
import ExyteMediaPicker
final class LLMChatViewModel: ObservableObject {
private var llm: LLMInferenceEngineWrapper?
@@ -21,6 +20,7 @@ final class LLMChatViewModel: ObservableObject {
@Published var messages: [Message] = []
@Published var isModelLoaded = false
@Published var isProcessing: Bool = false
@Published var currentStreamingMessageId: String? = nil
@Published var useMmap: Bool = false
@@ -57,7 +57,7 @@
let modelConfigManager: ModelConfigManager
var isDiffusionModel: Bool {
return modelInfo.name.lowercased().contains("diffusion")
return modelInfo.modelName.lowercased().contains("diffusion")
}
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
@@ -88,7 +88,7 @@
), userType: .system)
}
if modelInfo.name.lowercased().contains("diffusion") {
if modelInfo.modelName.lowercased().contains("diffusion") {
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
Task { @MainActor in
print("Diffusion Model \(success)")
@@ -150,7 +150,7 @@
func sendToLLM(draft: DraftMessage) {
self.send(draft: draft, userType: .user)
if isModelLoaded {
if modelInfo.name.lowercased().contains("diffusion") {
if modelInfo.modelName.lowercased().contains("diffusion") {
self.getDiffusionResponse(draft: draft)
} else {
self.getLLMRespsonse(draft: draft)
@@ -166,9 +166,7 @@
Task {
let tempDir = FileManager.default.temporaryDirectory
let imageName = UUID().uuidString + ".jpg"
let tempImagePath = tempDir.appendingPathComponent(imageName).path
let tempImagePath = FileOperationManager.shared.generateTempImagePath().path
var lastProcess:Int32 = 0
@@ -198,7 +196,21 @@
func getLLMRespsonse(draft: DraftMessage) {
Task {
await llmState.setProcessing(true)
await MainActor.run { self.isProcessing = true }
await MainActor.run {
self.isProcessing = true
let emptyMessage = DraftMessage(
text: "",
thinkText: "",
medias: [],
recording: nil,
replyMessage: nil,
createdAt: Date()
)
self.send(draft: emptyMessage, userType: .assistant)
if let lastMessage = self.messages.last {
self.currentStreamingMessageId = lastMessage.id
}
}
var content = draft.text
let medias = draft.medias
@@ -209,18 +221,10 @@
continue
}
let isInTempDirectory = url.path.contains("/tmp/")
let fileName = url.lastPathComponent
if !isInTempDirectory {
guard let fileUrl = AssetExtractor.copyFileToTmpDirectory(from: url, fileName: fileName) else {
continue
}
let processedUrl = convertHEICImage(from: fileUrl)
content = "<img>\(processedUrl?.path ?? "")</img>" + content
} else {
let processedUrl = convertHEICImage(from: url)
content = "<img>\(processedUrl?.path ?? "")</img>" + content
if let processedUrl = FileOperationManager.shared.processImageFile(from: url, fileName: fileName) {
content = "<img>\(processedUrl.path)</img>" + content
}
}
@@ -232,13 +236,36 @@
let convertedContent = self.convertDeepSeekMutliChat(content: content)
await llmState.processContent(convertedContent, llm: self.llm) { [weak self] output in
await llmState.processContent(convertedContent, llm: self.llm, showPerformance: true) { [weak self] output in
guard let self = self else { return }
if output.contains("<eop>") {
// force flush
Task {
await UIUpdateOptimizer.shared.forceFlush { finalOutput in
if !finalOutput.isEmpty {
self.send(draft: DraftMessage(
text: finalOutput,
thinkText: "",
medias: [],
recording: nil,
replyMessage: nil,
createdAt: Date()
), userType: .assistant)
}
}
await MainActor.run {
self.isProcessing = false
self.currentStreamingMessageId = nil
}
await self.llmState.setProcessing(false)
}
return
}
Task { @MainActor in
if (output.contains("<eop>")) {
self?.isProcessing = false
await self?.llmState.setProcessing(false)
} else {
self?.send(draft: DraftMessage(
await UIUpdateOptimizer.shared.addUpdate(output) { output in
self.send(draft: DraftMessage(
text: output,
thinkText: "",
medias: [],
@@ -248,6 +275,7 @@
), userType: .assistant)
}
}
}
}
}
@@ -259,7 +287,7 @@
}
private func convertDeepSeekMutliChat(content: String) -> String {
if self.modelInfo.name.lowercased().contains("deepseek") {
if self.modelInfo.modelName.lowercased().contains("deepseek") {
/* format: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
*/
@@ -286,14 +314,11 @@
}
}
private func convertHEICImage(from url: URL) -> URL? {
var fileUrl = url
if fileUrl.isHEICImage() {
if let convertedUrl = AssetExtractor.convertHEICToJPG(heicUrl: fileUrl) {
fileUrl = convertedUrl
}
}
return fileUrl
// MARK: - Public Methods for File Operations
/// Cleans the model temporary folder using FileOperationManager
func cleanModelTmpFolder() {
FileOperationManager.shared.cleanModelTempFolder(modelPath: modelInfo.localPath)
}
func onStart() {
@@ -311,14 +336,18 @@
func onStop() {
ChatHistoryManager.shared.saveChat(
historyId: historyId,
modelId: modelInfo.modelId,
modelName: modelInfo.name,
modelId: modelInfo.id,
modelName: modelInfo.modelName,
messages: messages
)
interactor.disconnect()
llm = nil
self.cleanTmpFolder()
FileOperationManager.shared.cleanTempDirectories()
if !useMmap {
FileOperationManager.shared.cleanModelTempFolder(modelPath: modelInfo.localPath)
}
}
func loadMoreMessage(before message: Message) {
@@ -326,40 +355,4 @@
.sink { _ in }
.store(in: &subscriptions)
}
func cleanModelTmpFolder() {
let tmpFolderURL = URL(fileURLWithPath: self.modelInfo.localPath).appendingPathComponent("temp")
self.cleanFolder(tmpFolderURL: tmpFolderURL)
}
private func cleanTmpFolder() {
let fileManager = FileManager.default
let tmpDirectoryURL = fileManager.temporaryDirectory
self.cleanFolder(tmpFolderURL: tmpDirectoryURL)
if !useMmap {
cleanModelTmpFolder()
}
}
private func cleanFolder(tmpFolderURL: URL) {
let fileManager = FileManager.default
do {
let files = try fileManager.contentsOfDirectory(at: tmpFolderURL, includingPropertiesForKeys: nil)
for file in files {
if !file.absoluteString.lowercased().contains("networkdownload") {
do {
try fileManager.removeItem(at: file)
print("Deleted file: \(file.path)")
} catch {
print("Error deleting file: \(file.path), \(error.localizedDescription)")
}
}
}
} catch {
print("Error accessing tmp directory: \(error.localizedDescription)")
}
}
}


@@ -17,12 +17,14 @@ struct LLMChatView: View {
private let title: String
private let modelPath: String
private let recorderSettings = RecorderSettings(audioFormatID: kAudioFormatLinearPCM, sampleRate: 44100, numberOfChannels: 2, linearPCMBitDepth: 16)
private let recorderSettings = RecorderSettings(audioFormatID: kAudioFormatLinearPCM,
sampleRate: 44100, numberOfChannels: 2,
linearPCMBitDepth: 16)
@State private var showSettings = false
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
self.title = modelInfo.name
self.title = modelInfo.modelName
self.modelPath = modelInfo.localPath
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
_viewModel = StateObject(wrappedValue: viewModel)
@@ -32,6 +34,15 @@
ChatView(messages: viewModel.messages, chatType: .conversation) { draft in
viewModel.sendToLLM(draft: draft)
}
messageBuilder: { message, positionInGroup, positionInCommentsGroup, showContextMenuClosure, messageActionClosure, showAttachmentClosure in
LLMChatMessageView(
message: message,
positionInGroup: positionInGroup,
showContextMenuClosure: showContextMenuClosure,
messageActionClosure: messageActionClosure,
showAttachmentClosure: showAttachmentClosure
)
}
.setAvailableInput(
self.title.lowercased().contains("vl") ? .textAndMedia :
self.title.lowercased().contains("audio") ? .textAndAudio :
@@ -109,4 +120,28 @@
}
.onDisappear(perform: viewModel.onStop)
}
// MARK: - LLM Chat Message Builder
@ViewBuilder
private func LLMChatMessageView(
message: Message,
positionInGroup: PositionInUserGroup,
showContextMenuClosure: @escaping () -> Void,
messageActionClosure: @escaping (Message, DefaultMessageMenuAction) -> Void,
showAttachmentClosure: @escaping (Attachment) -> Void
) -> some View {
LLMMessageView(
message: message,
positionInGroup: positionInGroup,
isAssistantMessage: !message.user.isCurrentUser,
isStreamingMessage: viewModel.currentStreamingMessageId == message.id,
showContextMenuClosure: {
if !viewModel.isProcessing {
showContextMenuClosure()
}
},
messageActionClosure: messageActionClosure,
showAttachmentClosure: showAttachmentClosure
)
}
}


@@ -0,0 +1,338 @@
//
// LLMMessageTextView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/7.
//
import SwiftUI
import MarkdownUI
/**
* LLMMessageTextView - A specialized text view designed for LLM chat messages with typewriter animation
*
* This SwiftUI component provides an enhanced text display specifically designed for AI chat applications.
* It supports both plain text and Markdown rendering with an optional typewriter animation effect
* that creates a dynamic, engaging user experience during AI response streaming.
*
* Key Features:
* - Typewriter animation for streaming AI responses
* - Markdown support with custom styling
* - Smart animation control based on message type and content length
* - Automatic animation management with proper cleanup
* - Performance-optimized character-by-character rendering
*
* Usage Examples:
*
* 1. Basic AI Message with Typewriter Effect:
* ```swift
* LLMMessageTextView(
* text: "Hello! This is an AI response with typewriter animation.",
* messageUseMarkdown: false,
* messageId: "msg_001",
* isAssistantMessage: true,
* isStreamingMessage: true
* )
* ```
*
* 2. Markdown Message with Custom Styling:
* ```swift
* LLMMessageTextView(
* text: "**Bold text** and *italic text* with `code blocks`",
* messageUseMarkdown: true,
* messageId: "msg_002",
* isAssistantMessage: true,
* isStreamingMessage: true
* )
* ```
*
* 3. User Message (No Animation):
* ```swift
* LLMMessageTextView(
* text: "This is a user message",
* messageUseMarkdown: false,
* messageId: "msg_003",
* isAssistantMessage: false,
* isStreamingMessage: false
* )
* ```
*
* Animation Configuration:
* - typingSpeed: 0.015 seconds per character (adjustable)
* - chunkSize: 1 character per animation frame
* - Minimum text length for animation: 5 characters
* - Auto-cleanup on view disappear or streaming completion
*/
struct LLMMessageTextView: View {
let text: String?
let messageUseMarkdown: Bool
let messageId: String
let isAssistantMessage: Bool
let isStreamingMessage: Bool // Whether this message is currently being streamed
@State private var displayedText: String = ""
@State private var animationTimer: Timer?
// Typewriter animation configuration
private let typingSpeed: TimeInterval = 0.015 // Time interval per character
private let chunkSize: Int = 1 // Number of characters to display per frame
init(text: String?,
messageUseMarkdown: Bool = false,
messageId: String,
isAssistantMessage: Bool = false,
isStreamingMessage: Bool = false) {
self.text = text
self.messageUseMarkdown = messageUseMarkdown
self.messageId = messageId
self.isAssistantMessage = isAssistantMessage
self.isStreamingMessage = isStreamingMessage
}
var body: some View {
Group {
if let text = text, !text.isEmpty {
if isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
typewriterView(text)
} else {
staticView(text)
}
}
}
.onAppear {
if let text = text, isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
startTypewriterAnimation(for: text)
} else if let text = text {
displayedText = text
}
}
.onDisappear {
stopAnimation()
}
.onChange(of: text) { oldText, newText in
handleTextChange(newText)
}
.onChange(of: isStreamingMessage) { oldIsStreaming, newIsStreaming in
if !newIsStreaming {
// Streaming ended, display complete text
if let text = text {
displayedText = text
}
stopAnimation()
}
}
}
/**
* Determines whether typewriter animation should be used
*
* Animation is enabled only for assistant messages with more than 5 characters
* to avoid unnecessary animation for short responses.
*/
private var shouldUseTypewriter: Bool {
// Enable typewriter effect only for assistant messages with sufficient length
return isAssistantMessage && (text?.count ?? 0) > 5
}
/**
* Renders text with typewriter animation effect
*
* - Parameter text: The complete text to be animated
* - Returns: A view displaying the animated text with optional Markdown support
*/
@ViewBuilder
private func typewriterView(_ text: String) -> some View {
if messageUseMarkdown {
Markdown(displayedText)
.markdownBlockStyle(\.blockquote) { configuration in
configuration.label
.padding()
.markdownTextStyle {
FontSize(13)
FontWeight(.light)
BackgroundColor(nil)
}
.overlay(alignment: .leading) {
Rectangle()
.fill(Color.gray)
.frame(width: 4)
}
.background(Color.gray.opacity(0.2))
}
} else {
Text(displayedText)
}
}
/**
* Renders static text without animation
*
* - Parameter text: The text to be displayed
* - Returns: A view displaying the complete text with optional Markdown support
*/
@ViewBuilder
private func staticView(_ text: String) -> some View {
if messageUseMarkdown {
Markdown(text)
.markdownBlockStyle(\.blockquote) { configuration in
configuration.label
.padding()
.markdownTextStyle {
FontSize(13)
FontWeight(.light)
BackgroundColor(nil)
}
.overlay(alignment: .leading) {
Rectangle()
.fill(Color.gray)
.frame(width: 4)
}
.background(Color.gray.opacity(0.2))
}
} else {
Text(text)
}
}
/**
* Handles text content changes during streaming
*
* This method intelligently manages animation continuation, restart, or direct display
* based on the relationship between old and new text content.
*
* - Parameter newText: The updated text content
*/
private func handleTextChange(_ newText: String?) {
guard let newText = newText else {
displayedText = ""
stopAnimation()
return
}
if isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
// Check if new text is an extension of current displayed text
if newText.hasPrefix(displayedText) && newText != displayedText {
// Continue typewriter animation
continueTypewriterAnimation(with: newText)
} else if newText != displayedText {
// Restart animation with new content
restartTypewriterAnimation(with: newText)
}
} else {
// Display text directly without animation
displayedText = newText
stopAnimation()
}
}
/**
* Initiates typewriter animation for the given text
*
* - Parameter text: The text to animate
*/
private func startTypewriterAnimation(for text: String) {
displayedText = ""
continueTypewriterAnimation(with: text)
}
/**
* Continues or resumes typewriter animation
*
* This method sets up a timer-based animation that progressively reveals
* characters at the configured typing speed.
*
* - Parameter text: The complete text to animate
*/
private func continueTypewriterAnimation(with text: String) {
guard displayedText.count < text.count else { return }
stopAnimation()
animationTimer = Timer.scheduledTimer(withTimeInterval: typingSpeed, repeats: true) { timer in
DispatchQueue.main.async {
self.appendNextCharacters(from: text)
}
}
}
/**
* Restarts typewriter animation with new content
*
* - Parameter text: The new text to animate
*/
private func restartTypewriterAnimation(with text: String) {
stopAnimation()
displayedText = ""
startTypewriterAnimation(for: text)
}
/**
* Appends the next character(s) to the displayed text
*
* This method is called by the animation timer to progressively reveal
* text characters. It handles proper string indexing and animation completion.
*
* - Parameter text: The source text to extract characters from
*/
private func appendNextCharacters(from text: String) {
let currentLength = displayedText.count
guard currentLength < text.count else {
stopAnimation()
return
}
let endIndex = min(currentLength + chunkSize, text.count)
let startIndex = text.index(text.startIndex, offsetBy: currentLength)
let targetIndex = text.index(text.startIndex, offsetBy: endIndex)
let newChars = text[startIndex..<targetIndex]
displayedText.append(String(newChars))
if displayedText.count >= text.count {
stopAnimation()
}
}
/**
* Stops and cleans up the typewriter animation
*
* This method should be called when animation is no longer needed
* to prevent memory leaks and unnecessary timer execution.
*/
private func stopAnimation() {
animationTimer?.invalidate()
animationTimer = nil
}
}
// MARK: - Preview Provider
struct LLMMessageTextView_Previews: PreviewProvider {
static var previews: some View {
VStack(spacing: 20) {
LLMMessageTextView(
text: "This is a typewriter animation demo text. Hello, this demonstrates the streaming effect!",
messageUseMarkdown: false,
messageId: "test1",
isAssistantMessage: true,
isStreamingMessage: true
)
LLMMessageTextView(
text: "**Bold text** and *italic text* with markdown support.",
messageUseMarkdown: true,
messageId: "test2",
isAssistantMessage: true,
isStreamingMessage: true
)
LLMMessageTextView(
text: "Regular user message without animation.",
messageUseMarkdown: false,
messageId: "test3",
isAssistantMessage: false,
isStreamingMessage: false
)
}
.padding()
}
}


@@ -0,0 +1,275 @@
//
// LLMMessageView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/7.
//
import SwiftUI
import Foundation
import SwiftUI
import ExyteChat
// MARK: - Custom Message View
struct LLMMessageView: View {
let message: Message
let positionInGroup: PositionInUserGroup
let isAssistantMessage: Bool
let isStreamingMessage: Bool
let showContextMenuClosure: () -> Void
let messageActionClosure: (Message, DefaultMessageMenuAction) -> Void
let showAttachmentClosure: (Attachment) -> Void
let theme = ChatTheme(
colors: .init(
messageMyBG: .customBlue.opacity(0.2),
messageFriendBG: .clear
),
images: .init(
attach: Image(systemName: "photo"),
attachCamera: Image("attachCamera", bundle: .current)
)
)
@State var avatarViewSize: CGSize = .zero
@State var timeSize: CGSize = .zero
static let widthWithMedia: CGFloat = 204
static let horizontalNoAvatarPadding: CGFloat = 8
static let horizontalAvatarPadding: CGFloat = 8
static let horizontalTextPadding: CGFloat = 12
static let horizontalAttachmentPadding: CGFloat = 1
static let horizontalBubblePadding: CGFloat = 70
var additionalMediaInset: CGFloat {
message.attachments.count > 1 ? LLMMessageView.horizontalAttachmentPadding * 2 : 0
}
var showAvatar: Bool {
positionInGroup == .single
|| positionInGroup == .last
}
var topPadding: CGFloat {
positionInGroup == .single || positionInGroup == .first ? 8 : 4
}
var bottomPadding: CGFloat {
0
}
var body: some View {
HStack(alignment: .top, spacing: 0) {
if !message.user.isCurrentUser {
avatarView
}
VStack(alignment: message.user.isCurrentUser ? .trailing : .leading, spacing: 2) {
bubbleView(message)
}
}
.padding(.top, topPadding)
.padding(.bottom, bottomPadding)
.padding(.trailing, message.user.isCurrentUser ? LLMMessageView.horizontalNoAvatarPadding : 0)
.padding(message.user.isCurrentUser ? .leading : .trailing, message.user.isCurrentUser ? LLMMessageView.horizontalBubblePadding : 0)
.frame(maxWidth: UIScreen.main.bounds.width, alignment: message.user.isCurrentUser ? .trailing : .leading)
.contentShape(Rectangle())
.onLongPressGesture {
showContextMenuClosure()
}
}
@ViewBuilder
func bubbleView(_ message: Message) -> some View {
VStack(alignment: .leading, spacing: 0) {
if !message.attachments.isEmpty {
attachmentsView(message)
}
if !message.text.isEmpty {
textWithTimeView(message)
}
if let recording = message.recording {
VStack(alignment: .trailing, spacing: 8) {
recordingView(recording)
messageTimeView()
.padding(.bottom, 8)
.padding(.trailing, 12)
}
}
}
.bubbleBackground(message, theme: theme)
}
@ViewBuilder
var avatarView: some View {
Group {
if showAvatar {
AsyncImage(url: message.user.avatarURL) { image in
image
.resizable()
.aspectRatio(contentMode: .fill)
} placeholder: {
Circle()
.fill(Color.gray.opacity(0.3))
}
.frame(width: 32, height: 32)
.clipShape(Circle())
.contentShape(Circle())
} else {
Color.clear.frame(width: 32, height: 32)
}
}
.padding(.horizontal, LLMMessageView.horizontalAvatarPadding)
.sizeGetter($avatarViewSize)
}
@ViewBuilder
func attachmentsView(_ message: Message) -> some View {
ForEach(message.attachments, id: \.id) { attachment in
AsyncImage(url: attachment.thumbnail) { image in
image
.resizable()
.aspectRatio(contentMode: .fit)
} placeholder: {
Rectangle()
.fill(Color.gray.opacity(0.3))
}
.frame(maxWidth: LLMMessageView.widthWithMedia, maxHeight: 200)
.cornerRadius(12)
.onTapGesture {
showAttachmentClosure(attachment)
}
}
.applyIf(message.attachments.count > 1) {
$0
.padding(.top, LLMMessageView.horizontalAttachmentPadding)
.padding(.horizontal, LLMMessageView.horizontalAttachmentPadding)
}
.overlay(alignment: .bottomTrailing) {
if message.text.isEmpty {
messageTimeView(needsCapsule: true)
.padding(4)
}
}
.contentShape(Rectangle())
}
@ViewBuilder
func textWithTimeView(_ message: Message) -> some View {
// Message View with Type Writer Animation
let messageView = LLMMessageTextView(
text: message.text,
messageUseMarkdown: true,
messageId: message.id,
isAssistantMessage: isAssistantMessage,
isStreamingMessage: isStreamingMessage
)
.fixedSize(horizontal: false, vertical: true)
.padding(.horizontal, LLMMessageView.horizontalTextPadding)
HStack(alignment: .lastTextBaseline, spacing: 12) {
messageView
if !message.attachments.isEmpty {
Spacer()
}
}
.padding(.vertical, 8)
}
@ViewBuilder
func recordingView(_ recording: Recording) -> some View {
HStack {
Image(systemName: "mic.fill")
.foregroundColor(.blue)
Text("Audio Message")
.font(.caption)
.foregroundColor(.secondary)
}
.padding(.horizontal, LLMMessageView.horizontalTextPadding)
.padding(.top, 8)
}
func messageTimeView(needsCapsule: Bool = false) -> some View {
Group {
if needsCapsule {
Text(DateFormatter.timeFormatter.string(from: message.createdAt))
.font(.caption)
.foregroundColor(.white)
.opacity(0.8)
.padding(.top, 4)
.padding(.bottom, 4)
.padding(.horizontal, 8)
.background {
Capsule()
.foregroundColor(.black.opacity(0.4))
}
} else {
Text(DateFormatter.timeFormatter.string(from: message.createdAt))
.font(.caption)
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyTimeText : theme.colors.messageFriendTimeText)
}
}
.sizeGetter($timeSize)
}
}
// MARK: - View Extensions
extension View {
@ViewBuilder
func sizeGetter(_ size: Binding<CGSize>) -> some View {
self.background(
GeometryReader { geometry in
Color.clear
.preference(key: SizePreferenceKey.self, value: geometry.size)
}
)
.onPreferenceChange(SizePreferenceKey.self) { newSize in
size.wrappedValue = newSize
}
}
@ViewBuilder
func applyIf<Content: View>(_ condition: Bool, transform: (Self) -> Content) -> some View {
if condition {
transform(self)
} else {
self
}
}
}
// MARK: - Preference Key
struct SizePreferenceKey: PreferenceKey {
static var defaultValue: CGSize = .zero
static func reduce(value: inout CGSize, nextValue: () -> CGSize) {}
}
// MARK: - Date Formatter Extension
extension DateFormatter {
static let timeFormatter: DateFormatter = {
let formatter = DateFormatter()
formatter.timeStyle = .short
return formatter
}()
}
extension View {
@ViewBuilder
func bubbleBackground(_ message: Message, theme: ChatTheme, isReply: Bool = false) -> some View {
let radius: CGFloat = !message.attachments.isEmpty ? 12 : 20
let additionalMediaInset: CGFloat = message.attachments.count > 1 ? 2 : 0
self
.frame(width: message.attachments.isEmpty ? nil : LLMMessageView.widthWithMedia + additionalMediaInset)
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyText : theme.colors.messageFriendText)
.background {
if isReply || !message.text.isEmpty || message.recording != nil {
RoundedRectangle(cornerRadius: radius)
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyBG : theme.colors.messageFriendBG)
.opacity(isReply ? 0.5 : 1)
}
}
.cornerRadius(radius)
}
}


@@ -37,7 +37,7 @@ struct ModelSettingsView: View {
var body: some View {
NavigationView {
Form {
Section(header: Text("Model Configuration")) {
Section {
Toggle("Use mmap", isOn: $viewModel.useMmap)
.onChange(of: viewModel.useMmap) { newValue in
viewModel.modelConfigManager.updateUseMmap(newValue)
@@ -47,11 +47,13 @@
viewModel.cleanModelTmpFolder()
showAlert = true
}
} header: {
Text("Model Configuration")
}
// Diffusion Settings
if viewModel.isDiffusionModel {
Section(header: Text("Diffusion Settings")) {
Section {
Stepper(value: $iterations, in: 1...100) {
HStack {
Text("Iterations")
@@ -84,9 +86,11 @@
}
}
}
} header: {
Text("Diffusion Settings")
}
} else {
Section(header: Text("Sampling Strategy")) {
Section {
Picker("Sampler Type", selection: $selectedSampler) {
ForEach(SamplerType.allCases, id: \.self) { sampler in
Text(sampler.displayName)
@@ -218,6 +222,8 @@
default:
EmptyView()
}
} header: {
Text("Sampling Strategy")
}
}
}
@@ -272,5 +278,3 @@
viewModel.modelConfigManager.updateMixedSamplers(orderedSelection)
}
}


@@ -1,43 +0,0 @@
//
// ChatHistoryItemView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/16.
//
import SwiftUI
struct ChatHistoryItemView: View {
let history: ChatHistory
var body: some View {
HStack(spacing: 12) {
ModelIconView(modelId: history.modelId)
.frame(width: 36, height: 36)
.clipShape(Circle())
VStack(alignment: .leading, spacing: 4) {
if let firstMessage = history.messages.last {
Text(String(firstMessage.content.prefix(50)) + "...")
.lineLimit(2)
.font(.system(size: 14))
}
HStack {
VStack(alignment: .leading) {
Text(history.modelName)
.font(.system(size: 12))
.foregroundColor(.gray)
Text(history.updatedAt.formatAgo())
.font(.system(size: 10))
.foregroundColor(.gray)
}
}
}
}
.padding(.vertical, 8)
}
}


@@ -1,88 +0,0 @@
//
// SideMenuView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/16.
//
import SwiftUI
struct SideMenuView: View {
@Binding var isOpen: Bool
@Binding var selectedHistory: ChatHistory?
@Binding var histories: [ChatHistory]
@State private var showingAlert = false
@State private var historyToDelete: ChatHistory?
@State private var dragOffset: CGFloat = 0
var body: some View {
GeometryReader { geometry in
VStack {
HStack {
Text(NSLocalizedString("ChatHistroyTitle", comment: "Chat History Title"))
.fontWeight(.medium)
.font(.system(size: 20))
Spacer()
}
.padding(.top, 80)
.padding(.leading)
List {
ForEach(histories.sorted(by: { $0.updatedAt > $1.updatedAt })) { history in
ChatHistoryItemView(history: history)
.onTapGesture {
selectedHistory = history
isOpen = false
}
.onLongPressGesture {
historyToDelete = history
showingAlert = true
}
.listRowBackground(Color.sidemenuBg)
}
}
.background(Color.sidemenuBg)
.listStyle(PlainListStyle())
}
.background(Color.sidemenuBg)
.frame(width: geometry.size.width * 0.8)
.offset(x: isOpen ? 0 : -geometry.size.width * 0.8)
.animation(.easeOut, value: isOpen)
.gesture(
DragGesture()
.onChanged { value in
if value.translation.width < 0 {
dragOffset = value.translation.width
}
}
.onEnded { value in
if value.translation.width < -geometry.size.width * 0.25 {
isOpen = false
}
dragOffset = 0
}
)
.alert("Delete History", isPresented: $showingAlert) {
Button("Cancel", role: .cancel) {}
Button("Delete", role: .destructive) {
if let history = historyToDelete {
deleteHistory(history)
}
}
} message: {
Text("Are you sure you want to delete this history?")
}
}
}
private func deleteHistory(_ history: ChatHistory) {
//
ChatHistoryManager.shared.deleteHistory(history)
//
histories.removeAll { $0.id == history.id }
}
}


@@ -0,0 +1,45 @@
//
// ChatHistoryItemView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/16.
//
import SwiftUI
struct ChatHistoryItemView: View {
let history: ChatHistory
var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let firstMessage = history.messages.last {
Text(String(firstMessage.content.prefix(200)))
.lineLimit(1)
.font(.system(size: 15, weight: .medium))
.foregroundColor(.primary)
}
HStack(alignment: .bottom) {
ModelIconView(modelId: history.modelId)
.frame(width: 20, height: 20)
.clipShape(Circle())
.padding(.trailing, 0)
Text(history.modelName)
.lineLimit(1)
.font(.system(size: 12, weight: .semibold))
.foregroundColor(.black.opacity(0.5))
Spacer()
Text(history.updatedAt.formatAgo())
.font(.system(size: 12, weight: .regular))
.foregroundColor(.black.opacity(0.5))
}
}
.padding(.vertical, 10)
.padding(.horizontal, 0)
}
}


@@ -0,0 +1,136 @@
//
// SideMenuView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/16.
//
import SwiftUI
struct SideMenuView: View {
@Binding var isOpen: Bool
@Binding var selectedHistory: ChatHistory?
@Binding var histories: [ChatHistory]
@Binding var navigateToMainSettings: Bool
@State private var showingAlert = false
@State private var historyToDelete: ChatHistory?
@State private var navigateToSettings = false
@State private var dragOffset: CGFloat = 0
var body: some View {
ZStack {
GeometryReader { geometry in
VStack {
HStack {
Text(NSLocalizedString("ChatHistroyTitle", comment: "Chat History Title"))
.fontWeight(.medium)
.font(.system(size: 20))
Spacer()
}
.padding(.top, 80)
.padding(.leading, 12)
List {
ForEach(histories.sorted(by: { $0.updatedAt > $1.updatedAt })) { history in
ChatHistoryItemView(history: history)
.onTapGesture {
selectedHistory = history
isOpen = false
}
.onLongPressGesture {
historyToDelete = history
showingAlert = true
}
.listRowBackground(Color.sidemenuBg)
.listRowSeparator(.hidden)
}
}
.background(Color.sidemenuBg)
.listStyle(PlainListStyle())
Spacer()
HStack {
Button(action: {
isOpen = false
DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
navigateToMainSettings = true
}
}) {
HStack {
Image(systemName: "gear")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 20, height: 20)
}
.foregroundColor(.primary)
.padding(.leading)
}
Spacer()
}
.padding(EdgeInsets(top: 10, leading: 12, bottom: 30, trailing: 0))
}
.background(Color.sidemenuBg)
.frame(width: geometry.size.width * 0.8)
.offset(x: isOpen ? 0 : -geometry.size.width * 0.8)
.animation(.easeOut, value: isOpen)
.gesture(
DragGesture()
.onChanged { value in
if value.translation.width < 0 {
dragOffset = value.translation.width
}
}
.onEnded { value in
if value.translation.width < -geometry.size.width * 0.25 {
isOpen = false
}
dragOffset = 0
}
)
.alert("Delete History", isPresented: $showingAlert) {
Button("Cancel", role: .cancel) {}
Button(LocalizedStringKey("button.delete"), role: .destructive) {
if let history = historyToDelete {
deleteHistory(history)
}
}
} message: {
Text("Are you sure you want to delete this history?")
}
}
}
}
private func deleteHistory(_ history: ChatHistory) {
ChatHistoryManager.shared.deleteHistory(history)
histories.removeAll { $0.id == history.id }
}
}
struct SettingsFullScreenView: View {
@Binding var isPresented: Bool
var body: some View {
NavigationView {
SettingsView()
.navigationTitle("Settings")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarLeading) {
Button(action: {
isPresented = false
}) {
Image(systemName: "xmark")
.foregroundColor(.primary)
.fontWeight(.medium)
}
}
}
}
.navigationViewStyle(StackNavigationViewStyle())
}
}


@@ -0,0 +1,209 @@
//
// LLMInferenceEngineWrapper.h
// mnn-llm
//
// Created by wangzhaode on 2023/12/14.
//
#ifndef LLMInferenceEngineWrapper_h
#define LLMInferenceEngineWrapper_h
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
typedef void (^CompletionHandler)(BOOL success);
typedef void (^OutputHandler)(NSString * _Nonnull output);
// MARK: - Benchmark Related Types
/**
* Progress type enumeration for structured benchmark reporting
*/
typedef NS_ENUM(NSInteger, BenchmarkProgressType) {
BenchmarkProgressTypeUnknown = 0,
BenchmarkProgressTypeInitializing = 1,
BenchmarkProgressTypeWarmingUp = 2,
BenchmarkProgressTypeRunningTest = 3,
BenchmarkProgressTypeProcessingResults = 4,
BenchmarkProgressTypeCompleted = 5,
BenchmarkProgressTypeStopping = 6
};
/**
* Structured progress information for benchmark
*/
@interface BenchmarkProgressInfo : NSObject
@property (nonatomic, assign) NSInteger progress; // 0-100
@property (nonatomic, strong) NSString *statusMessage; // Status description
@property (nonatomic, assign) BenchmarkProgressType progressType;
@property (nonatomic, assign) NSInteger currentIteration;
@property (nonatomic, assign) NSInteger totalIterations;
@property (nonatomic, assign) NSInteger nPrompt;
@property (nonatomic, assign) NSInteger nGenerate;
@property (nonatomic, assign) float runTimeSeconds;
@property (nonatomic, assign) float prefillTimeSeconds;
@property (nonatomic, assign) float decodeTimeSeconds;
@property (nonatomic, assign) float prefillSpeed;
@property (nonatomic, assign) float decodeSpeed;
@end
/**
* Benchmark result structure
*/
@interface BenchmarkResult : NSObject
@property (nonatomic, assign) BOOL success;
@property (nonatomic, strong, nullable) NSString *errorMessage;
@property (nonatomic, strong) NSArray<NSNumber *> *prefillTimesUs;
@property (nonatomic, strong) NSArray<NSNumber *> *decodeTimesUs;
@property (nonatomic, strong) NSArray<NSNumber *> *sampleTimesUs;
@property (nonatomic, assign) NSInteger promptTokens;
@property (nonatomic, assign) NSInteger generateTokens;
@property (nonatomic, assign) NSInteger repeatCount;
@property (nonatomic, assign) BOOL kvCacheEnabled;
@end
// Benchmark callback blocks
typedef void (^BenchmarkProgressCallback)(BenchmarkProgressInfo *progressInfo);
typedef void (^BenchmarkErrorCallback)(NSString *error);
typedef void (^BenchmarkIterationCompleteCallback)(NSString *detailedStats);
typedef void (^BenchmarkCompleteCallback)(BenchmarkResult *result);
/**
* LLMInferenceEngineWrapper - A high-level Objective-C wrapper for MNN LLM inference engine
*
* This class provides a convenient interface for integrating MNN's Large Language Model
* inference capabilities into iOS applications with enhanced error handling, performance
* optimization, and thread safety.
*/
@interface LLMInferenceEngineWrapper : NSObject
/**
* Initialize the LLM inference engine with a model path
*
* @param modelPath The file system path to the model directory
* @param completion Completion handler called with success/failure status
* @return Initialized instance of LLMInferenceEngineWrapper
*/
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion;
/**
* Process user input and generate streaming LLM response
*
* @param input The user's input text to process
* @param output Callback block that receives streaming output chunks
*/
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output;
/**
* Process user input and generate streaming LLM response with optional performance output
*
* @param input The user's input text to process
* @param output Callback block that receives streaming output chunks
* @param showPerformance Whether to output performance statistics after response completion
*/
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output showPerformance:(BOOL)showPerformance;
/**
* Add chat prompts from an array of dictionaries to the conversation history
*
* @param array NSArray containing NSDictionary objects with chat messages
*/
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array;
/**
* Set the configuration for the LLM engine using a JSON string
*
* @param jsonStr JSON string containing configuration parameters
*/
- (void)setConfigWithJSONString:(NSString *)jsonStr;
/**
* Check if model is ready for inference
*
* @return YES if model is loaded and ready
*/
- (BOOL)isModelReady;
/**
* Get current processing status
*
* @return YES if currently processing an inference request
*/
- (BOOL)isProcessing;
/**
* Cancel ongoing inference (if supported)
*/
- (void)cancelInference;
/**
* Get chat history count
*
* @return Number of messages in chat history
*/
- (NSUInteger)getChatHistoryCount;
/**
* Clear chat history
*/
- (void)clearChatHistory;
// MARK: - Benchmark Methods
/**
* Run the official benchmark following the llm_bench.cpp approach
*
* @param backend Backend type (0 for CPU)
* @param threads Number of threads
* @param useMmap Whether to use memory mapping
* @param power Power setting
* @param precision Precision setting (2 for low precision)
* @param memory Memory setting (2 for low memory)
* @param dynamicOption Dynamic optimization option
* @param nPrompt Number of prompt tokens
* @param nGenerate Number of tokens to generate
* @param nRepeat Number of repetitions
* @param kvCache Whether to use KV cache
* @param progressCallback Progress update callback
* @param errorCallback Error callback
* @param iterationCompleteCallback Iteration completion callback
* @param completeCallback Final completion callback
*/
- (void)runOfficialBenchmarkWithBackend:(NSInteger)backend
threads:(NSInteger)threads
useMmap:(BOOL)useMmap
power:(NSInteger)power
precision:(NSInteger)precision
memory:(NSInteger)memory
dynamicOption:(NSInteger)dynamicOption
nPrompt:(NSInteger)nPrompt
nGenerate:(NSInteger)nGenerate
nRepeat:(NSInteger)nRepeat
kvCache:(BOOL)kvCache
progressCallback:(BenchmarkProgressCallback _Nullable)progressCallback
errorCallback:(BenchmarkErrorCallback _Nullable)errorCallback
iterationCompleteCallback:(BenchmarkIterationCompleteCallback _Nullable)iterationCompleteCallback
completeCallback:(BenchmarkCompleteCallback _Nullable)completeCallback;
/**
* Stop running benchmark
*/
- (void)stopBenchmark;
/**
* Check if benchmark is currently running
*
* @return YES if benchmark is running
*/
- (BOOL)isBenchmarkRunning;
@end
NS_ASSUME_NONNULL_END
#endif /* LLMInferenceEngineWrapper_h */
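As a usage illustration (not part of this commit), a hypothetical Swift call site for the wrapper API above; the method names assume standard Objective-C-to-Swift bridging, and every argument value is an assumption rather than a recommended setting:

```swift
import Foundation

// Illustrative only: streaming inference plus an official benchmark run.
func exerciseWrapper(_ wrapper: LLMInferenceEngineWrapper) {
    // Streaming inference with performance statistics enabled.
    wrapper.processInput("Hello!", withOutput: { chunk in
        print(chunk, terminator: "")
    }, showPerformance: true)

    // Official benchmark on CPU; values chosen for illustration.
    wrapper.runOfficialBenchmark(
        withBackend: 0,        // 0 = CPU
        threads: 4,
        useMmap: true,
        power: 0,
        precision: 2,          // low precision
        memory: 2,             // low memory
        dynamicOption: 0,
        nPrompt: 512,
        nGenerate: 128,
        nRepeat: 3,
        kvCache: true,
        progressCallback: { info in
            print("\(info.progress)%: \(info.statusMessage)")
        },
        errorCallback: { error in
            print("Benchmark error: \(error)")
        },
        iterationCompleteCallback: { stats in
            print(stats)
        },
        completeCallback: { result in
            print(result.success ? "Benchmark finished"
                                 : "Benchmark failed: \(result.errorMessage ?? "unknown error")")
        })
}
```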

File diff suppressed because it is too large.


@@ -1,33 +0,0 @@
//
// LLMInferenceEngineWrapper.h
// mnn-llm
//
// Created by wangzhaode on 2023/12/14.
//
#ifndef LLMInferenceEngineWrapper_h
#define LLMInferenceEngineWrapper_h
// LLMInferenceEngineWrapper.h
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
typedef void (^CompletionHandler)(BOOL success);
typedef void (^OutputHandler)(NSString * _Nonnull output);
@interface LLMInferenceEngineWrapper : NSObject
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion;
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output;
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array;
- (void)setConfigWithJSONString:(NSString *)jsonStr;
@end
NS_ASSUME_NONNULL_END
#endif /* LLMInferenceEngineWrapper_h */

View File

@ -1,274 +0,0 @@
//
// LLMInferenceEngineWrapper.m
// mnn-llm
//
// Created by wangzhaode on 2023/12/14.
//
#include <iostream>
#include <string>
#include <unistd.h>
#include <sys/stat.h>
#include <filesystem>
#include <functional>
#include <MNN/llm/llm.hpp>
#include <vector>
#include <utility>
#import <Foundation/Foundation.h>
#import "LLMInferenceEngineWrapper.h"
using namespace MNN::Transformer;
using ChatMessage = std::pair<std::string, std::string>;
static std::vector<ChatMessage> history{};
@implementation LLMInferenceEngineWrapper {
std::shared_ptr<Llm> llm;
}
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion {
self = [super init];
if (self) {
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
BOOL success = [self loadModelFromPath:modelPath];
// MARK: Test Local Model
// BOOL success = [self loadModel];
dispatch_async(dispatch_get_main_queue(), ^{
completion(success);
});
});
}
return self;
}
bool remove_directory(const std::string& path) {
try {
std::filesystem::remove_all(path); // remove the directory and its contents
return true;
} catch (const std::filesystem::filesystem_error& e) {
std::cerr << "Error removing directory: " << e.what() << std::endl;
return false;
}
}
- (BOOL)loadModel {
if (!llm) {
NSString *bundleDirectory = [[NSBundle mainBundle] bundlePath];
std::string model_dir = [bundleDirectory UTF8String];
std::string config_path = model_dir + "/config.json";
llm.reset(Llm::createLLM(config_path));
NSString *tempDirectory = NSTemporaryDirectory();
llm->set_config("{\"tmp_path\":\"" + std::string([tempDirectory UTF8String]) + "\", \"use_mmap\":true}");
llm->load();
}
return YES;
}
- (BOOL)loadModelFromPath:(NSString *)modelPath {
if (!llm) {
std::string config_path = std::string([modelPath UTF8String]) + "/config.json";
// Read the config file to get use_mmap value
NSError *error = nil;
NSData *configData = [NSData dataWithContentsOfFile:[NSString stringWithUTF8String:config_path.c_str()]];
NSDictionary *configDict = [NSJSONSerialization JSONObjectWithData:configData options:0 error:&error];
// If use_mmap key doesn't exist, default to YES
BOOL useMmap = configDict[@"use_mmap"] == nil ? YES : [configDict[@"use_mmap"] boolValue];
llm.reset(Llm::createLLM(config_path));
if (!llm) {
return NO;
}
// Create temp directory inside the modelPath folder
std::string model_path_str([modelPath UTF8String]);
std::string temp_directory_path = model_path_str + "/temp";
struct stat info;
if (stat(temp_directory_path.c_str(), &info) == 0) {
// Directory exists, so remove it
if (!remove_directory(temp_directory_path)) {
std::cerr << "Failed to remove existing temp directory: " << temp_directory_path << std::endl;
return NO;
}
std::cerr << "Existing temp directory removed: " << temp_directory_path << std::endl;
}
// Now create the temp directory
if (mkdir(temp_directory_path.c_str(), 0777) != 0) {
std::cerr << "Failed to create temp directory: " << temp_directory_path << std::endl;
return NO;
}
std::cerr << "Temp directory created: " << temp_directory_path << std::endl;
// NSLog(@"useMmap value: %@", useMmap ? @"YES" : @"NO");
// Explicitly convert BOOL to bool and ensure proper string conversion
bool useMmapCpp = (useMmap == YES);
std::string configStr = "{\"tmp_path\":\"" + temp_directory_path + "\", \"use_mmap\":" + (useMmapCpp ? "true" : "false") + "}";
// Debug print to check the final config string
// NSLog(@"Config string: %s", configStr.c_str());
llm->set_config(configStr);
llm->load();
}
else {
std::cerr << "Warmming:: LLM have already been created!" << std::endl;
}
return YES;
}
- (void)setConfigWithJSONString:(NSString *)jsonStr {
if (!llm) {
return;
}
if (jsonStr) {
const char *cString = [jsonStr UTF8String];
std::string stdString(cString);
llm->set_config(stdString);
} else {
NSLog(@"Error: JSON string is nil or invalid.");
}
}
// llm stream buffer with callback
class LlmStreamBuffer : public std::streambuf {
public:
using CallBack = std::function<void(const char* str, size_t len)>;
LlmStreamBuffer(CallBack callback) : callback_(callback) {}
protected:
virtual std::streamsize xsputn(const char* s, std::streamsize n) override {
if (callback_) {
callback_(s, n);
}
return n;
}
private:
CallBack callback_ = nullptr;
};
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output {
if (llm == nil) {
output(@"Error: Model not loaded");
return;
}
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_LOW, 0), ^{
LlmStreamBuffer::CallBack callback = [output](const char* str, size_t len) {
if (output) {
NSString *nsOutput = [[NSString alloc] initWithBytes:str
length:len
encoding:NSUTF8StringEncoding];
if (nsOutput) {
output(nsOutput);
}
}
};
LlmStreamBuffer streambuf(callback);
std::ostream os(&streambuf);
history.emplace_back(ChatMessage("user", [input UTF8String]));
if (std::string([input UTF8String]) == "benchmark") {
[self performBenchmarkWithOutput:&os];
} else {
llm->response(history, &os, "<eop>", 999999);
}
});
}
// New method to handle benchmarking
- (void)performBenchmarkWithOutput:(std::ostream *)os {
std::string model_dir = [[[NSBundle mainBundle] bundlePath] UTF8String];
std::string prompt_file = model_dir + "/bench.txt";
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
while (std::getline(prompt_fs, prompt)) {
if (prompt.substr(0, 1) == "#") {
continue;
}
std::string::size_type pos = 0;
while ((pos = prompt.find("\\n", pos)) != std::string::npos) {
prompt.replace(pos, 2, "\n");
pos += 1;
}
prompts.push_back(prompt);
}
int prompt_len = 0;
int decode_len = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
auto context = llm->getContext();
for (const auto& p : prompts) {
llm->response(p, os, "\n");
prompt_len += context->prompt_len;
decode_len += context->gen_seq_len;
prefill_time += context->prefill_us;
decode_time += context->decode_us;
}
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
*os << "\n#################################\n"
<< "prompt tokens num = " << prompt_len << "\n"
<< "decode tokens num = " << decode_len << "\n"
<< "prefill time = " << std::fixed << std::setprecision(2) << prefill_s << " s\n"
<< "decode time = " << std::fixed << std::setprecision(2) << decode_s << " s\n"
<< "prefill speed = " << std::fixed << std::setprecision(2) << (prefill_s > 0 ? prompt_len / prefill_s : 0) << " tok/s\n"
<< "decode speed = " << std::fixed << std::setprecision(2) << (decode_s > 0 ? decode_len / decode_s : 0) << " tok/s\n"
<< "##################################\n";
*os << "<eop>";
}
- (void)dealloc {
std::cerr << "llm dealloc reset" << std::endl;
history.clear();
llm.reset();
llm = nil;
}
- (void)init:(const std::vector<std::string>&)chatHistory {
history.clear();
history.emplace_back("system", "You are a helpful assistant.");
for (size_t i = 0; i < chatHistory.size(); ++i) {
history.emplace_back(i % 2 == 0 ? "user" : "assistant", chatHistory[i]);
}
}
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array {
history.clear();
for (NSDictionary *dict in array) {
[self addPromptsFromDictionary:dict];
}
}
- (void)addPromptsFromDictionary:(NSDictionary *)dictionary {
for (NSString *key in dictionary) {
NSString *value = dictionary[key];
std::string keyString = [key UTF8String];
std::string valueString = [value UTF8String];
history.emplace_back(ChatMessage(keyString, valueString));
}
}
@end

View File

@ -12,11 +12,15 @@ struct MNNLLMiOSApp: App {
init() {
UIView.appearance().overrideUserInterfaceStyle = .light
let savedLanguage = LanguageManager.shared.currentLanguage
UserDefaults.standard.set([savedLanguage], forKey: "AppleLanguages")
UserDefaults.standard.synchronize()
}
var body: some Scene {
WindowGroup {
ModelListView()
MainTabView()
}
}
}

View File

@ -0,0 +1,144 @@
//
// BenchmarkResultsHelper.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
import Darwin
/**
* Helper class for processing and formatting benchmark test results.
* Provides statistical analysis, formatting utilities, and device information
* for benchmark result display and sharing.
*/
class BenchmarkResultsHelper {
static let shared = BenchmarkResultsHelper()
private init() {}
// MARK: - Results Processing & Statistics
/// Processes test results to generate comprehensive benchmark statistics
/// - Parameter testResults: Array of completed test instances
/// - Returns: Processed statistics including speed metrics and configuration details
func processTestResults(_ testResults: [TestInstance]) -> BenchmarkStatistics {
guard !testResults.isEmpty else {
return BenchmarkStatistics.empty
}
let firstTest = testResults[0]
let configText = "Backend: CPU, Threads: \(firstTest.threads), Memory: Low, Precision: Low"
var prefillStats: SpeedStatistics?
var decodeStats: SpeedStatistics?
var totalTokensProcessed = 0
// Calculate prefill (prompt processing) statistics
let allPrefillSpeeds = testResults.flatMap { test in
test.getTokensPerSecond(tokens: test.nPrompt, timesUs: test.prefillUs)
}
if !allPrefillSpeeds.isEmpty {
let avgPrefill = allPrefillSpeeds.reduce(0, +) / Double(allPrefillSpeeds.count)
let stdevPrefill = calculateStandardDeviation(values: allPrefillSpeeds, mean: avgPrefill)
prefillStats = SpeedStatistics(average: avgPrefill, stdev: stdevPrefill, label: "Prompt Processing")
}
// Calculate decode (token generation) statistics
let allDecodeSpeeds = testResults.flatMap { test in
test.getTokensPerSecond(tokens: test.nGenerate, timesUs: test.decodeUs)
}
if !allDecodeSpeeds.isEmpty {
let avgDecode = allDecodeSpeeds.reduce(0, +) / Double(allDecodeSpeeds.count)
let stdevDecode = calculateStandardDeviation(values: allDecodeSpeeds, mean: avgDecode)
decodeStats = SpeedStatistics(average: avgDecode, stdev: stdevDecode, label: "Token Generation")
}
// Calculate total tokens processed across all tests
totalTokensProcessed = testResults.reduce(0) { sum, test in
return sum + (test.nPrompt * test.prefillUs.count) + (test.nGenerate * test.decodeUs.count)
}
return BenchmarkStatistics(
configText: configText,
prefillStats: prefillStats,
decodeStats: decodeStats,
totalTokensProcessed: totalTokensProcessed,
totalTests: testResults.count
)
}
/// Calculates standard deviation for a set of values
/// - Parameters:
/// - values: Array of numeric values
/// - mean: Pre-calculated mean of the values
/// - Returns: Standard deviation value
private func calculateStandardDeviation(values: [Double], mean: Double) -> Double {
guard values.count > 1 else { return 0.0 }
let variance = values.reduce(0) { sum, value in
let diff = value - mean
return sum + (diff * diff)
} / Double(values.count - 1)
return sqrt(variance)
}
// MARK: - Formatting & Display
/// Formats speed statistics with average and standard deviation
/// - Parameter stats: Speed statistics to format
/// - Returns: Formatted string like "42.5 ± 3.2 tok/s"
func formatSpeedStatisticsLine(_ stats: SpeedStatistics) -> String {
return String(format: "%.1f ± %.1f tok/s", stats.average, stats.stdev)
}
/// Returns the label-only portion of speed statistics
/// - Parameter stats: Speed statistics object
/// - Returns: Human-readable label (e.g., "Prompt Processing")
func formatSpeedLabelOnly(_ stats: SpeedStatistics) -> String {
return stats.label
}
/// Formats model parameter summary for display
/// - Parameters:
/// - totalTokens: Total number of tokens processed
/// - totalTests: Total number of tests completed
/// - Returns: Formatted summary string
func formatModelParams(totalTokens: Int, totalTests: Int) -> String {
return "Total Tokens: \(totalTokens), Tests: \(totalTests)"
}
/// Formats memory usage with percentage and absolute values
/// - Parameters:
/// - maxMemoryKb: Peak memory usage in kilobytes
/// - totalKb: Total system memory in kilobytes
/// - Returns: Tuple containing formatted value and percentage label
func formatMemoryUsage(maxMemoryKb: Int64, totalKb: Int64) -> (valueText: String, labelText: String) {
let maxMemoryMB = Double(maxMemoryKb) / 1024.0
let totalMemoryGB = Double(totalKb) / (1024.0 * 1024.0)
let percentage = (Double(maxMemoryKb) / Double(totalKb)) * 100.0
let valueText = String(format: "%.1f MB", maxMemoryMB)
let labelText = String(format: "%.1f%% of %.1f GB", percentage, totalMemoryGB)
return (valueText, labelText)
}
// MARK: - Device & System Information
/// Gets comprehensive device information including model and iOS version
/// - Returns: Formatted device info string (e.g., "iPhone 14 Pro, iOS 17.0")
func getDeviceInfo() -> String {
return DeviceInfoHelper.shared.getDeviceInfo()
}
/// Gets total system memory in kilobytes
/// - Returns: System memory size in KB
func getTotalSystemMemoryKb() -> Int64 {
return Int64(ProcessInfo.processInfo.physicalMemory) / 1024
}
}
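A quick sketch of how the helper composes with TestInstance (defined later in this change); the timing numbers are invented for illustration:
let test = TestInstance(modelConfigFile: "config.json", modelType: "demo",
                        threads: 4, useMmap: false,
                        nPrompt: 256, nGenerate: 64,
                        backend: 0, precision: 2, power: 0, memory: 2, dynamicOption: 0)
test.prefillUs = [500_000, 520_000, 480_000]  // microseconds per repetition
test.decodeUs  = [800_000, 790_000, 810_000]
let stats = BenchmarkResultsHelper.shared.processTestResults([test])
if let prefill = stats.prefillStats {
    // Prints roughly "512.5 ± 20.5 tok/s" for the numbers above.
    print(BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(prefill))
}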

View File

@ -0,0 +1,22 @@
//
// BenchmarkErrorCode.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Enumeration of possible error codes that can occur during benchmark execution.
* Provides specific error identification for different failure scenarios.
*/
enum BenchmarkErrorCode: Int {
case benchmarkFailedUnknown = 30
case testInstanceFailed = 40
case modelNotInitialized = 50
case benchmarkRunning = 99
case benchmarkStopped = 100
case nativeError = 0
case modelError = 2
}

View File

@ -0,0 +1,53 @@
//
// BenchmarkProgress.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure containing detailed progress information for benchmark execution.
* Provides real-time metrics including timing data and performance statistics.
*/
struct BenchmarkProgress {
let progress: Int // 0-100
let statusMessage: String
let progressType: ProgressType
let currentIteration: Int
let totalIterations: Int
let nPrompt: Int
let nGenerate: Int
let runTimeSeconds: Float
let prefillTimeSeconds: Float
let decodeTimeSeconds: Float
let prefillSpeed: Float
let decodeSpeed: Float
init(progress: Int,
statusMessage: String,
progressType: ProgressType = .unknown,
currentIteration: Int = 0,
totalIterations: Int = 0,
nPrompt: Int = 0,
nGenerate: Int = 0,
runTimeSeconds: Float = 0.0,
prefillTimeSeconds: Float = 0.0,
decodeTimeSeconds: Float = 0.0,
prefillSpeed: Float = 0.0,
decodeSpeed: Float = 0.0) {
self.progress = progress
self.statusMessage = statusMessage
self.progressType = progressType
self.currentIteration = currentIteration
self.totalIterations = totalIterations
self.nPrompt = nPrompt
self.nGenerate = nGenerate
self.runTimeSeconds = runTimeSeconds
self.prefillTimeSeconds = prefillTimeSeconds
self.decodeTimeSeconds = decodeTimeSeconds
self.prefillSpeed = prefillSpeed
self.decodeSpeed = decodeSpeed
}
}

View File

@ -0,0 +1,24 @@
//
// BenchmarkResult.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure containing the results of a completed benchmark test.
* Encapsulates test instance data along with success status and error information.
*/
struct BenchmarkResult {
let testInstance: TestInstance
let success: Bool
let errorMessage: String?
init(testInstance: TestInstance, success: Bool, errorMessage: String? = nil) {
self.testInstance = testInstance
self.success = success
self.errorMessage = errorMessage
}
}

View File

@ -0,0 +1,26 @@
//
// BenchmarkResults.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure containing comprehensive benchmark results for display and sharing.
* Aggregates test results, memory usage, and metadata for result presentation.
*/
struct BenchmarkResults {
let modelDisplayName: String
let maxMemoryKb: Int64
let testResults: [TestInstance]
let timestamp: String
init(modelDisplayName: String, maxMemoryKb: Int64, testResults: [TestInstance], timestamp: String) {
self.modelDisplayName = modelDisplayName
self.maxMemoryKb = maxMemoryKb
self.testResults = testResults
self.timestamp = timestamp
}
}

View File

@ -0,0 +1,28 @@
//
// BenchmarkStatistics.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure containing comprehensive statistical analysis of benchmark results.
* Aggregates performance metrics, configuration details, and test summary information.
*/
struct BenchmarkStatistics {
let configText: String
let prefillStats: SpeedStatistics?
let decodeStats: SpeedStatistics?
let totalTokensProcessed: Int
let totalTests: Int
static let empty = BenchmarkStatistics(
configText: "",
prefillStats: nil,
decodeStats: nil,
totalTokensProcessed: 0,
totalTests: 0
)
}

View File

@ -0,0 +1,142 @@
//
// DeviceInfoHelper.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
import UIKit
/**
* Helper class for retrieving device information including model identification
* and system details. Provides device-specific information for benchmark results.
*/
class DeviceInfoHelper {
static let shared = DeviceInfoHelper()
private init() {}
/// Gets the device model identifier (e.g., "iPhone14,7")
func getDeviceIdentifier() -> String {
var systemInfo = utsname()
uname(&systemInfo)
let machineMirror = Mirror(reflecting: systemInfo.machine)
let identifier = machineMirror.children.reduce("") { identifier, element in
guard let value = element.value as? Int8, value != 0 else { return identifier }
return identifier + String(UnicodeScalar(UInt8(value)))
}
return identifier
}
/// Gets the user-friendly device name (e.g., "iPhone 13 mini")
func getDeviceModelName() -> String {
let identifier = getDeviceIdentifier()
return mapIdentifierToModelName(identifier)
}
/// Gets detailed device information including model and system version
func getDeviceInfo() -> String {
let device = UIDevice.current
let systemVersion = device.systemVersion
let modelName = getDeviceModelName()
return "\(modelName), iOS \(systemVersion)"
}
private func mapIdentifierToModelName(_ identifier: String) -> String {
// iPhone mappings
let iPhoneMappings: [String: String] = [
// iPhone 13 series
"iPhone14,4": "iPhone 13 mini",
"iPhone14,5": "iPhone 13",
"iPhone14,2": "iPhone 13 Pro",
"iPhone14,3": "iPhone 13 Pro Max",
// iPhone 14 series
"iPhone14,7": "iPhone 14",
"iPhone14,8": "iPhone 14 Plus",
"iPhone15,2": "iPhone 14 Pro",
"iPhone15,3": "iPhone 14 Pro Max",
// iPhone 15 series
"iPhone15,4": "iPhone 15",
"iPhone15,5": "iPhone 15 Plus",
"iPhone16,1": "iPhone 15 Pro",
"iPhone16,2": "iPhone 15 Pro Max",
// iPhone 16 series
"iPhone17,1": "iPhone 16",
"iPhone17,2": "iPhone 16 Plus",
"iPhone17,3": "iPhone 16 Pro",
"iPhone17,4": "iPhone 16 Pro Max",
// iPhone SE series
"iPhone12,8": "iPhone SE (2nd generation)",
"iPhone14,6": "iPhone SE (3rd generation)",
// Older iPhones
"iPhone13,1": "iPhone 12 mini",
"iPhone13,2": "iPhone 12",
"iPhone13,3": "iPhone 12 Pro",
"iPhone13,4": "iPhone 12 Pro Max",
"iPhone12,1": "iPhone 11",
"iPhone12,3": "iPhone 11 Pro",
"iPhone12,5": "iPhone 11 Pro Max",
]
// iPad mappings
let iPadMappings: [String: String] = [
// iPad Pro 12.9-inch
"iPad13,8": "iPad Pro (12.9-inch) (5th generation)",
"iPad13,9": "iPad Pro (12.9-inch) (5th generation)",
"iPad13,10": "iPad Pro (12.9-inch) (5th generation)",
"iPad13,11": "iPad Pro (12.9-inch) (5th generation)",
"iPad14,5": "iPad Pro (12.9-inch) (6th generation)",
"iPad14,6": "iPad Pro (12.9-inch) (6th generation)",
// iPad Pro 11-inch
"iPad13,4": "iPad Pro (11-inch) (3rd generation)",
"iPad13,5": "iPad Pro (11-inch) (3rd generation)",
"iPad13,6": "iPad Pro (11-inch) (3rd generation)",
"iPad13,7": "iPad Pro (11-inch) (3rd generation)",
"iPad14,3": "iPad Pro (11-inch) (4th generation)",
"iPad14,4": "iPad Pro (11-inch) (4th generation)",
// iPad Air
"iPad13,1": "iPad Air (4th generation)",
"iPad13,2": "iPad Air (4th generation)",
"iPad13,16": "iPad Air (5th generation)",
"iPad13,17": "iPad Air (5th generation)",
// iPad mini
"iPad14,1": "iPad mini (6th generation)",
"iPad14,2": "iPad mini (6th generation)",
// iPad (regular)
"iPad12,1": "iPad (9th generation)",
"iPad12,2": "iPad (9th generation)",
"iPad13,18": "iPad (10th generation)",
"iPad13,19": "iPad (10th generation)",
]
// Try iPhone mappings first
if let modelName = iPhoneMappings[identifier] {
return modelName
}
// Try iPad mappings
if let modelName = iPadMappings[identifier] {
return modelName
}
// Check for simulator
if identifier == "x86_64" || identifier == "i386" {
return "Simulator"
}
// Return raw identifier if no mapping found
return identifier
}
}
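Usage is a one-liner. One observation (not part of this change): Apple-silicon simulators report "arm64" from uname, so the x86_64/i386 check above falls through to the raw identifier there.
let info = DeviceInfoHelper.shared.getDeviceInfo()
print(info)  // e.g. "iPhone 15 Pro, iOS 17.0", or the raw identifier on unmapped hardware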

View File

@ -0,0 +1,34 @@
//
// ModelItem.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure representing a model item with download state information.
* Used for tracking model availability and download progress in the benchmark interface.
*/
struct ModelItem: Identifiable, Equatable {
let id = UUID()
let modelId: String
let displayName: String
let isLocal: Bool
let localPath: String?
let size: Int64?
let downloadState: DownloadState
enum DownloadState: Equatable {
case notStarted
case downloading(progress: Double)
case completed
case failed(error: String)
case paused
}
static func == (lhs: ModelItem, rhs: ModelItem) -> Bool {
return lhs.modelId == rhs.modelId
}
}

View File

@ -0,0 +1,31 @@
//
// ModelListManager.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Manager class for integrating with ModelListViewModel to provide
* downloaded models for benchmark testing.
*/
class ModelListManager {
static let shared = ModelListManager()
private let modelListViewModel = ModelListViewModel()
private init() {}
/// Loads available models, filtering for downloaded models suitable for benchmarking
/// - Returns: Array of downloaded ModelInfo objects
/// - Throws: Error if model loading fails
func loadModels() async throws -> [ModelInfo] {
// Ensure models are loaded from the view model
await modelListViewModel.fetchModels()
// Return only downloaded models that are available for benchmark
return modelListViewModel.models.filter { $0.isDownloaded }
}
}
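A hypothetical call site, e.g. from a view model:
Task {
    do {
        let models = try await ModelListManager.shared.loadModels()
        print("Downloaded models available for benchmarking: \(models.count)")
    } catch {
        print("Failed to load models: \(error)")
    }
}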

View File

@ -0,0 +1,34 @@
//
// ProgressType.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Enumeration representing different stages of benchmark execution progress.
* Used to track and display the current state of benchmark operations.
*/
enum ProgressType: Int, CaseIterable {
case unknown = 0
case initializing
case warmingUp
case runningTest
case processingResults
case completed
case stopping
var description: String {
switch self {
case .unknown: return "Unknown"
case .initializing: return "Initializing benchmark..."
case .warmingUp: return "Warming up..."
case .runningTest: return "Running test"
case .processingResults: return "Processing results..."
case .completed: return "All tests completed"
case .stopping: return "Stopping benchmark..."
}
}
}

View File

@ -0,0 +1,32 @@
//
// RuntimeParameters.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Configuration parameters for benchmark runtime environment.
* Defines hardware and execution settings for benchmark tests.
*/
struct RuntimeParameters {
let backends: [Int]
let threads: [Int]
let useMmap: Bool
let power: [Int]
let precision: [Int]
let memory: [Int]
let dynamicOption: [Int]
static let `default` = RuntimeParameters(
backends: [0], // CPU
threads: [4],
useMmap: false,
power: [0],
precision: [2], // Low precision
memory: [2], // Low memory
dynamicOption: [0]
)
}

View File

@ -0,0 +1,24 @@
//
// SpeedStatistics.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Structure containing statistical analysis of benchmark speed metrics.
* Provides average, standard deviation, and descriptive label for performance data.
*/
struct SpeedStatistics {
let average: Double
let stdev: Double
let label: String
init(average: Double, stdev: Double, label: String) {
self.average = average
self.stdev = stdev
self.label = label
}
}

View File

@ -0,0 +1,71 @@
//
// TestInstance.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
import Combine
/**
* Observable class representing a single benchmark test instance.
* Contains test configuration parameters and stores timing results.
*/
class TestInstance: ObservableObject, Identifiable {
let id = UUID()
let modelConfigFile: String
let modelType: String
let modelSize: Int64
let threads: Int
let useMmap: Bool
let nPrompt: Int
let nGenerate: Int
let backend: Int
let precision: Int
let power: Int
let memory: Int
let dynamicOption: Int
@Published var prefillUs: [Int64] = []
@Published var decodeUs: [Int64] = []
@Published var samplesUs: [Int64] = []
init(modelConfigFile: String,
modelType: String,
modelSize: Int64 = 0,
threads: Int,
useMmap: Bool,
nPrompt: Int,
nGenerate: Int,
backend: Int,
precision: Int,
power: Int,
memory: Int,
dynamicOption: Int) {
self.modelConfigFile = modelConfigFile
self.modelType = modelType
self.modelSize = modelSize
self.threads = threads
self.useMmap = useMmap
self.nPrompt = nPrompt
self.nGenerate = nGenerate
self.backend = backend
self.precision = precision
self.power = power
self.memory = memory
self.dynamicOption = dynamicOption
}
/// Calculates tokens per second from timing data
/// - Parameters:
/// - tokens: Number of tokens processed
/// - timesUs: Array of timing measurements in microseconds
/// - Returns: Array of tokens per second calculations
func getTokensPerSecond(tokens: Int, timesUs: [Int64]) -> [Double] {
return timesUs.compactMap { timeUs in
guard timeUs > 0 else { return 0.0 }
return Double(tokens) * 1_000_000.0 / Double(timeUs)
}
}
}
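To make the conversion concrete: with 256 prompt tokens and a 500,000 µs prefill measurement, the speed is 256 × 1,000,000 / 500,000 = 512 tok/s. Note that the guard returns 0.0 rather than nil, so compactMap here behaves like map and zero timings contribute 0 tok/s entries. Assuming test is any TestInstance:
let speeds = test.getTokensPerSecond(tokens: 256, timesUs: [500_000, 0])
// speeds == [512.0, 0.0]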

View File

@ -0,0 +1,30 @@
//
// TestParameters.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
/**
* Configuration parameters for benchmark test execution.
* Defines test scenarios including prompt sizes, generation lengths, and repetition counts.
*/
struct TestParameters {
let nPrompt: [Int]
let nGenerate: [Int]
let nPrompGen: [(Int, Int)]
let nRepeat: [Int]
let kvCache: String
let loadTime: String
static let `default` = TestParameters(
nPrompt: [256, 512],
nGenerate: [64, 128],
nPrompGen: [(256, 64), (512, 128)],
nRepeat: [3], // Reduced for mobile
kvCache: "false", // llama-bench style test by default
loadTime: "false"
)
}

View File

@ -0,0 +1,414 @@
//
// BenchmarkService.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
import Combine
/**
* Protocol defining callback methods for benchmark execution events.
* Provides progress updates, completion notifications, and error handling.
*/
protocol BenchmarkCallback: AnyObject {
func onProgress(_ progress: BenchmarkProgress)
func onComplete(_ result: BenchmarkResult)
func onBenchmarkError(_ errorCode: Int, _ message: String)
}
/**
* Singleton service class responsible for managing benchmark operations.
* Coordinates with LLMInferenceEngineWrapper to execute performance tests
* and provides real-time progress updates through callback mechanisms.
*/
class BenchmarkService: ObservableObject {
// MARK: - Singleton & Properties
static let shared = BenchmarkService()
@Published private(set) var isRunning = false
private var shouldStop = false
private var currentTask: Task<Void, Never>?
// Real LLM inference engine - using actual MNN LLM wrapper
private var llmEngine: LLMInferenceEngineWrapper?
private var currentModelId: String?
private init() {}
// MARK: - Public Interface
/// Initiates benchmark execution with specified parameters and callback handler
/// - Parameters:
/// - modelId: Identifier for the model to benchmark
/// - callback: Callback handler for progress and completion events
/// - runtimeParams: Runtime configuration parameters
/// - testParams: Test scenario parameters
func runBenchmark(
modelId: String,
callback: BenchmarkCallback,
runtimeParams: RuntimeParameters = .default,
testParams: TestParameters = .default
) {
guard !isRunning else {
callback.onBenchmarkError(BenchmarkErrorCode.benchmarkRunning.rawValue, "Benchmark is already running")
return
}
guard let engine = llmEngine, engine.isModelReady() else {
callback.onBenchmarkError(BenchmarkErrorCode.modelNotInitialized.rawValue, "Model is not initialized or not ready")
return
}
isRunning = true
shouldStop = false
currentTask = Task {
await performBenchmark(
engine: engine,
modelId: modelId,
callback: callback,
runtimeParams: runtimeParams,
testParams: testParams
)
}
}
/// Stops the currently running benchmark operation
func stopBenchmark() {
shouldStop = true
llmEngine?.stopBenchmark()
currentTask?.cancel()
isRunning = false
}
/// Checks if the model is properly initialized and ready for benchmarking
/// - Returns: True if model is ready, false otherwise
func isModelInitialized() -> Bool {
return llmEngine != nil && llmEngine!.isModelReady()
}
/// Initializes a model for benchmark testing
/// - Parameters:
/// - modelId: Identifier for the model
/// - modelPath: File system path to the model
/// - Returns: True if initialization succeeded, false otherwise
func initializeModel(modelId: String, modelPath: String) async -> Bool {
return await withCheckedContinuation { continuation in
// Release existing engine if any
llmEngine = nil
currentModelId = nil
// Create new LLM inference engine
llmEngine = LLMInferenceEngineWrapper(modelPath: modelPath) { success in
if success {
self.currentModelId = modelId
print("BenchmarkService: Model \(modelId) initialized successfully")
} else {
self.llmEngine = nil
print("BenchmarkService: Failed to initialize model \(modelId)")
}
continuation.resume(returning: success)
}
}
}
/// Retrieves information about the currently loaded model
/// - Returns: Model information string, or nil if no model is loaded
func getModelInfo() -> String? {
guard let modelId = currentModelId else { return nil }
return "Model: \(modelId), Engine: MNN LLM"
}
/// Releases the current model and frees associated resources
func releaseModel() {
llmEngine = nil
currentModelId = nil
}
// MARK: - Benchmark Execution
/// Performs the actual benchmark execution with progress tracking
/// - Parameters:
/// - engine: LLM inference engine instance
/// - modelId: Model identifier
/// - callback: Progress and completion callback handler
/// - runtimeParams: Runtime configuration
/// - testParams: Test parameters
private func performBenchmark(
engine: LLMInferenceEngineWrapper,
modelId: String,
callback: BenchmarkCallback,
runtimeParams: RuntimeParameters,
testParams: TestParameters
) async {
do {
let instances = generateTestInstances(runtimeParams: runtimeParams, testParams: testParams)
var completedInstances = 0
let totalInstances = instances.count
for instance in instances {
if shouldStop {
await MainActor.run {
callback.onBenchmarkError(BenchmarkErrorCode.benchmarkStopped.rawValue, "Benchmark stopped by user")
self.isRunning = false
}
return
}
// Create TestInstance for current configuration
let testInstance = TestInstance(
modelConfigFile: instance.configPath,
modelType: modelId,
modelSize: 0, // Will be calculated if needed
threads: instance.threads,
useMmap: instance.useMmap,
nPrompt: instance.nPrompt,
nGenerate: instance.nGenerate,
backend: instance.backend,
precision: instance.precision,
power: instance.power,
memory: instance.memory,
dynamicOption: instance.dynamicOption
)
// Update overall progress
let progress = (completedInstances * 100) / totalInstances
let statusMsg = "Running test \(completedInstances + 1)/\(totalInstances): pp\(instance.nPrompt)+tg\(instance.nGenerate)"
await MainActor.run {
callback.onProgress(BenchmarkProgress(
progress: progress,
statusMessage: statusMsg,
progressType: .runningTest,
currentIteration: completedInstances + 1,
totalIterations: totalInstances,
nPrompt: instance.nPrompt,
nGenerate: instance.nGenerate
))
}
// Execute benchmark using LLMInferenceEngineWrapper
let result = await runOfficialBenchmark(
engine: engine,
instance: instance,
testInstance: testInstance,
progressCallback: { progress in
await MainActor.run {
callback.onProgress(progress)
}
}
)
if result.success {
completedInstances += 1
// Only call onComplete for the last test instance
if completedInstances == totalInstances {
await MainActor.run {
callback.onComplete(result)
}
}
} else {
await MainActor.run {
callback.onBenchmarkError(BenchmarkErrorCode.testInstanceFailed.rawValue, result.errorMessage ?? "Test failed")
self.isRunning = false
}
return
}
}
await MainActor.run {
self.isRunning = false
}
} catch {
await MainActor.run {
callback.onBenchmarkError(BenchmarkErrorCode.nativeError.rawValue, error.localizedDescription)
self.isRunning = false
}
}
}
/// Executes a single benchmark test using the official MNN LLM benchmark interface
/// - Parameters:
/// - engine: LLM inference engine
/// - instance: Test configuration
/// - testInstance: Test instance to populate with results
/// - progressCallback: Callback for progress updates
/// - Returns: Benchmark result with success status and timing data
private func runOfficialBenchmark(
engine: LLMInferenceEngineWrapper,
instance: TestConfig,
testInstance: TestInstance,
progressCallback: @escaping (BenchmarkProgress) async -> Void
) async -> BenchmarkResult {
return await withCheckedContinuation { continuation in
var hasResumed = false
engine.runOfficialBenchmark(
withBackend: instance.backend,
threads: instance.threads,
useMmap: instance.useMmap,
power: instance.power,
precision: instance.precision,
memory: instance.memory,
dynamicOption: instance.dynamicOption,
nPrompt: instance.nPrompt,
nGenerate: instance.nGenerate,
nRepeat: instance.nRepeat,
kvCache: instance.kvCache,
progressCallback: { [self] progressInfo in
// Convert Objective-C BenchmarkProgressInfo to Swift BenchmarkProgress
let swiftProgress = BenchmarkProgress(
progress: Int(progressInfo.progress),
statusMessage: progressInfo.statusMessage,
progressType: convertProgressType(progressInfo.progressType),
currentIteration: Int(progressInfo.currentIteration),
totalIterations: Int(progressInfo.totalIterations),
nPrompt: Int(progressInfo.nPrompt),
nGenerate: Int(progressInfo.nGenerate),
runTimeSeconds: progressInfo.runTimeSeconds,
prefillTimeSeconds: progressInfo.prefillTimeSeconds,
decodeTimeSeconds: progressInfo.decodeTimeSeconds,
prefillSpeed: progressInfo.prefillSpeed,
decodeSpeed: progressInfo.decodeSpeed
)
Task {
await progressCallback(swiftProgress)
}
},
errorCallback: { errorMessage in
if !hasResumed {
hasResumed = true
let result = BenchmarkResult(
testInstance: testInstance,
success: false,
errorMessage: errorMessage
)
continuation.resume(returning: result)
}
},
iterationCompleteCallback: { detailedStats in
// Log detailed stats if needed
print("Benchmark iteration complete: \(detailedStats)")
},
completeCallback: { benchmarkResult in
if !hasResumed {
hasResumed = true
// Update test instance with timing results
testInstance.prefillUs = benchmarkResult.prefillTimesUs.map { $0.int64Value }
testInstance.decodeUs = benchmarkResult.decodeTimesUs.map { $0.int64Value }
testInstance.samplesUs = benchmarkResult.sampleTimesUs.map { $0.int64Value }
let result = BenchmarkResult(
testInstance: testInstance,
success: benchmarkResult.success,
errorMessage: benchmarkResult.errorMessage
)
continuation.resume(returning: result)
}
}
)
}
}
// MARK: - Helper Methods & Configuration
/// Converts Objective-C progress type to Swift enum
/// - Parameter objcType: Objective-C progress type
/// - Returns: Corresponding Swift ProgressType
private func convertProgressType(_ objcType: BenchmarkProgressType) -> ProgressType {
switch objcType {
case .unknown:
return .unknown
case .initializing:
return .initializing
case .warmingUp:
return .warmingUp
case .runningTest:
return .runningTest
case .processingResults:
return .processingResults
case .completed:
return .completed
case .stopping:
return .stopping
@unknown default:
return .unknown
}
}
/// Generates test instances by combining runtime and test parameters
/// - Parameters:
/// - runtimeParams: Runtime configuration parameters
/// - testParams: Test scenario parameters
/// - Returns: Array of test configurations for execution
private func generateTestInstances(
runtimeParams: RuntimeParameters,
testParams: TestParameters
) -> [TestConfig] {
var instances: [TestConfig] = []
for backend in runtimeParams.backends {
for thread in runtimeParams.threads {
for power in runtimeParams.power {
for precision in runtimeParams.precision {
for memory in runtimeParams.memory {
for dynamicOption in runtimeParams.dynamicOption {
for repeatCount in testParams.nRepeat {
for (nPrompt, nGenerate) in testParams.nPrompGen {
instances.append(TestConfig(
configPath: "", // Will be set based on model
backend: backend,
threads: thread,
useMmap: runtimeParams.useMmap,
power: power,
precision: precision,
memory: memory,
dynamicOption: dynamicOption,
nPrompt: nPrompt,
nGenerate: nGenerate,
nRepeat: repeatCount,
kvCache: testParams.kvCache == "true"
))
}
}
}
}
}
}
}
}
return instances
}
}
// MARK: - Test Configuration
/**
* Structure containing configuration parameters for a single benchmark test.
* Combines runtime settings and test parameters into a complete test specification.
*/
struct TestConfig {
let configPath: String
let backend: Int
let threads: Int
let useMmap: Bool
let power: Int
let precision: Int
let memory: Int
let dynamicOption: Int
let nPrompt: Int
let nGenerate: Int
let nRepeat: Int
let kvCache: Bool
}
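A minimal sketch of driving the service end to end; the conforming class, model id, and path are illustrative only. With the default parameters, generateTestInstances yields two configurations, one per (nPrompt, nGenerate) pair in nPrompGen:
final class ConsoleBenchmarkCallback: BenchmarkCallback {
    func onProgress(_ progress: BenchmarkProgress) {
        print("[\(progress.progress)%] \(progress.statusMessage)")
    }
    func onComplete(_ result: BenchmarkResult) {
        print("Done; prefill samples: \(result.testInstance.prefillUs.count)")
    }
    func onBenchmarkError(_ errorCode: Int, _ message: String) {
        print("Benchmark error \(errorCode): \(message)")
    }
}
let callback = ConsoleBenchmarkCallback()
Task {
    // Hypothetical id/path; initializeModel must succeed before runBenchmark.
    let ok = await BenchmarkService.shared.initializeModel(modelId: "demo-model",
                                                           modelPath: "/path/to/model")
    guard ok else { return }
    BenchmarkService.shared.runBenchmark(modelId: "demo-model", callback: callback)
}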

View File

@ -0,0 +1,444 @@
//
// BenchmarkViewModel.swift
// MNNLLMiOS
//
// Created by () on 2025/7/10.
//
import Foundation
import SwiftUI
import Combine
/**
* ViewModel for managing benchmark operations including model selection, test execution,
* progress tracking, and result management. Handles communication with BenchmarkService
* and provides UI state management for the benchmark interface.
*/
@MainActor
class BenchmarkViewModel: ObservableObject {
// MARK: - Published Properties
@Published var isLoading = false
@Published var isRunning = false
@Published var showProgressBar = false
@Published var showResults = false
@Published var showError = false
@Published var selectedModel: ModelInfo?
@Published var availableModels: [ModelInfo] = []
@Published var currentProgress: BenchmarkProgress?
@Published var benchmarkResults: BenchmarkResults?
@Published var errorMessage: String = ""
@Published var statusMessage: String = ""
@Published var startButtonText = String(localized: "Start Test")
@Published var isStartButtonEnabled = true
// MARK: - Private Properties
private let benchmarkService = BenchmarkService.shared
private let resultsHelper = BenchmarkResultsHelper.shared
private var cancellables = Set<AnyCancellable>()
// Model list manager for getting local models
private let modelListManager = ModelListManager.shared
// MARK: - Initialization & Setup
init() {
setupBindings()
loadAvailableModels()
}
/// Sets up reactive bindings between service and view model
private func setupBindings() {
benchmarkService.$isRunning
.receive(on: DispatchQueue.main)
.assign(to: \.isRunning, on: self)
.store(in: &cancellables)
// Update button text based on running state
benchmarkService.$isRunning
.receive(on: DispatchQueue.main)
.map { isRunning in
isRunning ? String(localized: "Stop Test") : String(localized: "Start Test")
}
.assign(to: \.startButtonText, on: self)
.store(in: &cancellables)
}
/// Loads available models from ModelListManager, filtering for downloaded models only
private func loadAvailableModels() {
Task {
isLoading = true
do {
// Get all models from ModelListManager
let allModels = try await modelListManager.loadModels()
// Filter only downloaded models that are available locally
availableModels = allModels.filter { model in
model.isDownloaded && model.localPath != nil
}
print("BenchmarkViewModel: Loaded \(availableModels.count) available local models")
} catch {
showErrorMessage("Failed to load models: \(error.localizedDescription)")
}
isLoading = false
}
}
// MARK: - Public Action Handlers
/// Handles start/stop benchmark button taps
func onStartBenchmarkTapped() {
if !isRunning {
startBenchmark()
} else {
showStopConfirmationAlert()
}
}
/// Handles benchmark stop confirmation
func onStopBenchmarkTapped() {
stopBenchmark()
}
/// Handles model selection from dropdown
func onModelSelected(_ model: ModelInfo) {
selectedModel = model
}
/// Handles result deletion and cleanup
func onDeleteResultTapped() {
benchmarkResults = nil
showResults = false
hideStatus()
// Release model to free memory
benchmarkService.releaseModel()
}
/// Placeholder for future result submission functionality
func onSubmitResultTapped() {
// Implementation for submitting results (if needed)
// This could involve sharing or uploading results
}
// MARK: - Benchmark Execution
/// Initiates benchmark test with selected model and configured parameters
private func startBenchmark() {
guard let model = selectedModel else {
showErrorMessage("Please select a model first")
return
}
guard model.isDownloaded else {
showErrorMessage("Selected model is not downloaded or path is invalid")
return
}
onBenchmarkStarted()
Task {
// Initialize model if needed
let initialized = await benchmarkService.initializeModel(
modelId: model.id,
modelPath: model.localPath
)
guard initialized else {
showErrorMessage("Failed to initialize model")
resetUIState()
return
}
// Start memory monitoring
MemoryMonitor.shared.start()
// Start benchmark with optimized parameters for mobile devices
benchmarkService.runBenchmark(
modelId: model.id,
callback: self,
runtimeParams: createRuntimeParameters(),
testParams: createTestParameters()
)
}
}
/// Creates runtime parameters optimized for iOS devices
private func createRuntimeParameters() -> RuntimeParameters {
return RuntimeParameters(
backends: [0], // CPU backend
threads: [4], // 4 threads for most iOS devices
useMmap: false, // Memory mapping disabled for iOS
power: [0], // Normal power mode
precision: [2], // Low precision for better performance
memory: [2], // Low memory usage
dynamicOption: [0] // No dynamic optimization
)
}
/// Creates test parameters suitable for mobile benchmarking
private func createTestParameters() -> TestParameters {
return TestParameters(
nPrompt: [256, 512], // Smaller prompt sizes for mobile
nGenerate: [64, 128], // Smaller generation sizes
nPrompGen: [(256, 64), (512, 128)], // Combined test cases
nRepeat: [3], // Fewer repetitions for faster testing
kvCache: "false", // Disable KV cache by default
loadTime: "false"
)
}
/// Stops the currently running benchmark
private func stopBenchmark() {
updateStatus("Stopping benchmark...")
benchmarkService.stopBenchmark()
MemoryMonitor.shared.stop()
}
// MARK: - UI State Management
/// Updates UI state when benchmark starts
private func onBenchmarkStarted() {
isStartButtonEnabled = true
showProgressBar = true
showResults = false
updateStatus("Initializing benchmark...")
}
/// Resets UI to initial state
private func resetUIState() {
isStartButtonEnabled = true
showProgressBar = false
hideStatus()
showResults = false
MemoryMonitor.shared.stop()
}
/// Updates status message display
private func updateStatus(_ message: String) {
statusMessage = message
}
/// Hides status message
private func hideStatus() {
statusMessage = ""
}
/// Shows error message alert
private func showErrorMessage(_ message: String) {
errorMessage = message
showError = true
}
/// Placeholder for stop confirmation alert (handled in View)
private func showStopConfirmationAlert() {
// This will be handled in the View with an alert
}
/// Formats progress messages with appropriate status text based on progress type
private func formatProgressMessage(_ progress: BenchmarkProgress) -> BenchmarkProgress {
let formattedMessage: String
switch progress.progressType {
case .initializing:
formattedMessage = "Initializing benchmark..."
case .warmingUp:
formattedMessage = "Warming up..."
case .runningTest:
formattedMessage = "Running test \(progress.currentIteration)/\(progress.totalIterations)"
case .processingResults:
formattedMessage = "Processing results..."
case .completed:
formattedMessage = "All tests completed"
case .stopping:
formattedMessage = "Stopping benchmark..."
default:
formattedMessage = progress.statusMessage
}
return BenchmarkProgress(
progress: progress.progress,
statusMessage: formattedMessage,
progressType: progress.progressType,
currentIteration: progress.currentIteration,
totalIterations: progress.totalIterations,
nPrompt: progress.nPrompt,
nGenerate: progress.nGenerate,
runTimeSeconds: progress.runTimeSeconds,
prefillTimeSeconds: progress.prefillTimeSeconds,
decodeTimeSeconds: progress.decodeTimeSeconds,
prefillSpeed: progress.prefillSpeed,
decodeSpeed: progress.decodeSpeed
)
}
}
// MARK: - BenchmarkCallback Implementation
extension BenchmarkViewModel: BenchmarkCallback {
/// Handles progress updates from benchmark service
func onProgress(_ progress: BenchmarkProgress) {
let formattedProgress = formatProgressMessage(progress)
currentProgress = formattedProgress
updateStatus(formattedProgress.statusMessage)
}
/// Handles benchmark completion with results processing
func onComplete(_ result: BenchmarkResult) {
guard let model = selectedModel else { return }
updateStatus("Processing results...")
// Create comprehensive benchmark results
let results = BenchmarkResults(
modelDisplayName: model.modelName,
maxMemoryKb: MemoryMonitor.shared.getMaxMemoryKb(),
testResults: [result.testInstance],
timestamp: DateFormatter.benchmarkTimestamp.string(from: Date())
)
benchmarkResults = results
showResults = true
// Only stop memory monitoring if benchmark is no longer running (all tests completed)
if !isRunning {
// Stop memory monitoring
MemoryMonitor.shared.stop()
}
// Always hide status after processing results
hideStatus()
print("BenchmarkViewModel: Benchmark completed successfully for model: \(model.modelName)")
}
/// Handles benchmark errors with user-friendly error messages
func onBenchmarkError(_ errorCode: Int, _ message: String) {
let errorCodeName = BenchmarkErrorCode(rawValue: errorCode)?.description ?? "Unknown"
showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
resetUIState()
print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
}
}
// MARK: - Memory Monitoring
/**
* Singleton class for monitoring memory usage during benchmark execution.
* Tracks current and peak memory consumption using system APIs.
*/
class MemoryMonitor: ObservableObject {
static let shared = MemoryMonitor()
@Published private(set) var currentMemoryKb: Int64 = 0
private var maxMemoryKb: Int64 = 0
private var isMonitoring = false
private var monitoringTask: Task<Void, Never>?
private init() {}
/// Starts continuous memory monitoring
func start() {
guard !isMonitoring else { return }
isMonitoring = true
maxMemoryKb = 0
monitoringTask = Task {
while isMonitoring && !Task.isCancelled {
await updateMemoryUsage()
try? await Task.sleep(nanoseconds: 500_000_000) // 0.5 seconds
}
}
}
/// Stops memory monitoring
func stop() {
isMonitoring = false
monitoringTask?.cancel()
monitoringTask = nil
}
/// Resets memory tracking counters
func reset() {
maxMemoryKb = 0
currentMemoryKb = 0
}
/// Returns the maximum memory usage recorded during monitoring
func getMaxMemoryKb() -> Int64 {
return maxMemoryKb
}
/// Updates current memory usage and tracks maximum
@MainActor
private func updateMemoryUsage() {
let memoryUsage = getCurrentMemoryUsage()
currentMemoryKb = memoryUsage
maxMemoryKb = max(maxMemoryKb, memoryUsage)
}
/// Gets current memory usage from system using mach task info
private func getCurrentMemoryUsage() -> Int64 {
var info = mach_task_basic_info()
var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4 // size in integer_t (4-byte) units, i.e. MACH_TASK_BASIC_INFO_COUNT
let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) {
$0.withMemoryRebound(to: integer_t.self, capacity: 1) {
task_info(mach_task_self_,
task_flavor_t(MACH_TASK_BASIC_INFO),
$0,
&count)
}
}
if kerr == KERN_SUCCESS {
return Int64(info.resident_size) / 1024 // Convert to KB
} else {
return 0
}
}
}
// MARK: - Extensions
/// Extension providing user-friendly descriptions for benchmark error codes
extension BenchmarkErrorCode {
var description: String {
switch self {
case .benchmarkFailedUnknown:
return "Unknown Error"
case .testInstanceFailed:
return "Test Failed"
case .modelNotInitialized:
return "Model Not Ready"
case .benchmarkRunning:
return "Already Running"
case .benchmarkStopped:
return "Stopped"
case .nativeError:
return "Native Error"
case .modelError:
return "Model Error"
}
}
}
/// Extension providing formatted timestamp for benchmark results
extension DateFormatter {
static let benchmarkTimestamp: DateFormatter = {
let formatter = DateFormatter()
formatter.dateFormat = "yyyy/M/dd HH:mm:ss"
return formatter
}()
}

View File

@ -0,0 +1,58 @@
//
// MetricCard.swift
// MNNLLMiOS
//
// Created by () on 2025/7/21.
//
import SwiftUI
/**
* Reusable metric display card component.
* Shows performance metrics with icon, title, and value in a compact format.
*/
struct MetricCard: View {
let title: String
let value: String
let icon: String
var body: some View {
VStack(alignment: .leading, spacing: 6) {
HStack(spacing: 6) {
Image(systemName: icon)
.font(.caption)
.foregroundColor(.benchmarkAccent)
Text(title)
.font(.caption)
.foregroundColor(.benchmarkSecondary)
.lineLimit(1)
}
Text(value)
.font(.system(size: 14, weight: .semibold))
.foregroundColor(.primary)
.lineLimit(1)
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding(.horizontal, 12)
.padding(.vertical, 8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(Color.benchmarkAccent.opacity(0.05))
.overlay(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.benchmarkAccent.opacity(0.1), lineWidth: 1)
)
)
}
}
#Preview {
HStack(spacing: 12) {
MetricCard(title: "Runtime", value: "2.456s", icon: "clock")
MetricCard(title: "Speed", value: "109.8 t/s", icon: "speedometer")
MetricCard(title: "Memory", value: "1.2 GB", icon: "memorychip")
}
.padding()
}

View File

@ -0,0 +1,241 @@
//
// ModelSelectionCard.swift
// MNNLLMiOS
//
// Created by () on 2025/7/21.
//
import SwiftUI
/**
* Reusable model selection card component for benchmark interface.
* Provides dropdown menu for model selection and start/stop controls.
*/
struct ModelSelectionCard: View {
@ObservedObject var viewModel: BenchmarkViewModel
@Binding var showStopConfirmation: Bool
var body: some View {
VStack(alignment: .leading, spacing: 16) {
HStack {
Text("Select Model")
.font(.title3)
.fontWeight(.semibold)
.foregroundColor(.primary)
Spacer()
}
if viewModel.isLoading {
HStack {
ProgressView()
.scaleEffect(0.8)
Text("Loading models...")
.font(.subheadline)
.foregroundColor(.secondary)
}
.frame(maxWidth: .infinity, alignment: .leading)
} else {
modelDropdownMenu
}
startStopButton
statusMessages
}
.padding(20)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(Color.benchmarkCardBg)
.overlay(
RoundedRectangle(cornerRadius: 16)
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
)
)
}
// MARK: - Private Views
private var modelDropdownMenu: some View {
Menu {
if viewModel.availableModels.isEmpty {
Button("No models available") {
// Placeholder - no action
}
.disabled(true)
} else {
ForEach(viewModel.availableModels, id: \.id) { model in
Button(action: {
viewModel.onModelSelected(model)
}) {
HStack {
VStack(alignment: .leading, spacing: 2) {
Text(model.modelName)
.font(.system(size: 14, weight: .medium))
Text("Local")
.font(.caption)
.foregroundColor(.secondary)
}
}
}
}
}
} label: {
HStack(spacing: 16) {
VStack(alignment: .leading, spacing: 6) {
Text(viewModel.selectedModel?.modelName ?? String(localized: "Choose your AI model"))
.font(.system(size: 16, weight: .medium))
.foregroundColor(viewModel.isRunning ? .secondary : (viewModel.selectedModel != nil ? .primary : .benchmarkSecondary))
.lineLimit(1)
if let model = viewModel.selectedModel {
HStack(spacing: 8) {
HStack(spacing: 4) {
Circle()
.fill(Color.benchmarkSuccess)
.frame(width: 6, height: 6)
Text("Ready")
.font(.caption)
.foregroundColor(.benchmarkSuccess)
}
if let size = model.cachedSize {
Text("\(formatBytes(size))")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
} else {
Text("Tap to select a model for testing")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
Spacer()
Image(systemName: "chevron.down")
.font(.system(size: 14, weight: .medium))
.foregroundColor(viewModel.isRunning ? .secondary : .benchmarkSecondary)
.rotationEffect(.degrees(0))
}
.padding(20)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(Color.benchmarkCardBg)
.overlay(
RoundedRectangle(cornerRadius: 16)
.stroke(
viewModel.isRunning ?
Color.gray.opacity(0.1) :
(viewModel.selectedModel != nil ?
Color.benchmarkAccent.opacity(0.3) :
Color.gray.opacity(0.2)),
lineWidth: 1
)
))
}
.disabled(viewModel.isRunning)
}
private var startStopButton: some View {
Button(action: {
if viewModel.startButtonText.contains("Stop") {
showStopConfirmation = true
} else {
viewModel.onStartBenchmarkTapped()
}
}) {
HStack(spacing: 12) {
ZStack {
Circle()
.fill(Color.white.opacity(0.2))
.frame(width: 32, height: 32)
if viewModel.isRunning && viewModel.startButtonText.contains("Stop") {
ProgressView()
.progressViewStyle(CircularProgressViewStyle(tint: .white))
.scaleEffect(0.7)
} else {
Image(systemName: viewModel.startButtonText.contains("Stop") ? "stop.fill" : "play.fill")
.font(.system(size: 16, weight: .bold))
.foregroundColor(.white)
}
}
Text(viewModel.startButtonText)
.font(.system(size: 18, weight: .semibold))
.foregroundColor(.white)
Spacer()
if !viewModel.startButtonText.contains("Stop") {
Image(systemName: "arrow.right")
.font(.system(size: 16, weight: .semibold))
.foregroundColor(.white.opacity(0.8))
}
}
.frame(maxWidth: .infinity)
.padding(.horizontal, 24)
.padding(.vertical, 18)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(
viewModel.isStartButtonEnabled ?
(viewModel.startButtonText.contains("Stop") ?
LinearGradient(
colors: [Color.benchmarkError, Color.benchmarkError.opacity(0.8)],
startPoint: .leading,
endPoint: .trailing
) :
LinearGradient(
colors: [Color.benchmarkGradientStart, Color.benchmarkGradientEnd],
startPoint: .leading,
endPoint: .trailing
)) :
LinearGradient(
colors: [Color.gray, Color.gray.opacity(0.8)],
startPoint: .leading,
endPoint: .trailing
)
)
)
}
.disabled(!viewModel.isStartButtonEnabled || viewModel.selectedModel == nil)
.animation(.easeInOut(duration: 0.2), value: viewModel.startButtonText)
.animation(.easeInOut(duration: 0.2), value: viewModel.isStartButtonEnabled)
}
private var statusMessages: some View {
Group {
if viewModel.selectedModel == nil {
Text("Start benchmark after selecting your model")
.font(.caption)
.foregroundColor(.orange)
.padding(.horizontal, 16)
} else if viewModel.availableModels.isEmpty {
Text("No local models found. Please download a model first.")
.font(.caption)
.foregroundColor(.orange)
.padding(.horizontal, 16)
}
}
}
// MARK: - Helper Functions
private func formatBytes(_ bytes: Int64) -> String {
let formatter = ByteCountFormatter()
formatter.allowedUnits = [.useGB, .useMB]
formatter.countStyle = .file
return formatter.string(fromByteCount: bytes)
}
}
#Preview {
ModelSelectionCard(
viewModel: BenchmarkViewModel(),
showStopConfirmation: .constant(false)
)
.padding()
}

View File

@ -0,0 +1,117 @@
//
// EnhancedPerformanceMetricView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/21.
//
import SwiftUI
/**
* Enhanced performance metric display component.
* Shows detailed performance metrics with gradient backgrounds, icons, and custom colors.
*/
struct PerformanceMetricView: View {
let icon: String
let title: String
let value: String
let subtitle: String
let color: Color
var body: some View {
VStack(alignment: .center, spacing: 12) {
ZStack {
Circle()
.fill(
LinearGradient(
colors: [color.opacity(0.2), color.opacity(0.1)],
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.frame(width: 50, height: 50)
Image(systemName: icon)
.font(.system(size: 25, weight: .semibold))
.foregroundColor(color)
}
VStack(alignment: .center, spacing: 2) {
Text(title)
.font(.subheadline)
.fontWeight(.medium)
.foregroundColor(.primary)
Text(subtitle)
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
Text(value)
.font(.title2)
.fontWeight(.bold)
.foregroundColor(color)
.multilineTextAlignment(.center)
.lineLimit(nil)
.fixedSize(horizontal: false, vertical: true)
}
.frame(maxWidth: .infinity, alignment: .center)
.padding(16)
.background(
RoundedRectangle(cornerRadius: 12)
.fill(
LinearGradient(
colors: [Color.benchmarkCardBg, color.opacity(0.02)],
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.overlay(
RoundedRectangle(cornerRadius: 12)
.stroke(color.opacity(0.2), lineWidth: 1)
)
)
}
}
#Preview {
VStack(spacing: 16) {
HStack(spacing: 12) {
PerformanceMetricView(
icon: "speedometer",
title: "Prefill Speed",
value: "1024.5 t/s",
subtitle: "Tokens per second",
color: .benchmarkGradientStart
)
PerformanceMetricView(
icon: "gauge",
title: "Decode Speed",
value: "109.8 t/s",
subtitle: "Generation rate",
color: .benchmarkGradientEnd
)
}
HStack(spacing: 12) {
PerformanceMetricView(
icon: "memorychip",
title: "Memory Usage",
value: "1.2 GB",
subtitle: "Peak memory",
color: .benchmarkWarning
)
PerformanceMetricView(
icon: "clock",
title: "Total Time",
value: "2.456s",
subtitle: "Complete duration",
color: .benchmarkSuccess
)
}
}
.padding()
}

View File

@ -0,0 +1,187 @@
//
// ProgressCard.swift
// MNNLLMiOS
//
// Created on 2025/7/21.
//
import SwiftUI
/**
* Reusable progress tracking card component for benchmark interface.
* Displays test progress with detailed metrics and visual indicators.
*/
struct ProgressCard: View {
let progress: BenchmarkProgress?
var body: some View {
VStack(alignment: .leading, spacing: 20) {
if let progress = progress {
VStack(alignment: .leading, spacing: 16) {
progressHeader(progress)
progressBar(progress)
if progress.progressType == .runningTest && progress.totalIterations > 0 {
testDetails(progress)
}
}
} else {
fallbackProgress
}
}
.padding(20)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(Color.benchmarkCardBg)
.overlay(
RoundedRectangle(cornerRadius: 16)
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
)
)
}
// MARK: - Private Views
private func progressHeader(_ progress: BenchmarkProgress) -> some View {
HStack {
HStack(spacing: 12) {
ZStack {
Circle()
.fill(
LinearGradient(
colors: [Color.benchmarkAccent.opacity(0.2), Color.benchmarkGradientEnd.opacity(0.1)],
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.frame(width: 40, height: 40)
Image(systemName: "chart.line.uptrend.xyaxis")
.font(.system(size: 18, weight: .semibold))
.foregroundColor(.benchmarkAccent)
}
VStack(alignment: .leading, spacing: 2) {
Text("Test Progress")
.font(.title3)
.fontWeight(.semibold)
.foregroundColor(.primary)
Text("Running performance tests")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
Spacer()
VStack(alignment: .trailing, spacing: 2) {
Text("\(progress.progress)%")
.font(.title2)
.fontWeight(.bold)
.foregroundColor(.benchmarkAccent)
Text("Complete")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
}
private func progressBar(_ progress: BenchmarkProgress) -> some View {
VStack(spacing: 8) {
ZStack(alignment: .leading) {
RoundedRectangle(cornerRadius: 8)
.fill(Color.gray.opacity(0.2))
.frame(height: 8)
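// Fill width maps progress (0-100) onto roughly 80% of the screen width, an approximation of the card's inner width.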
RoundedRectangle(cornerRadius: 8)
.fill(
LinearGradient(
colors: [Color.benchmarkGradientStart, Color.benchmarkGradientEnd],
startPoint: .leading,
endPoint: .trailing
)
)
.frame(width: CGFloat(progress.progress) / 100 * UIScreen.main.bounds.width * 0.8, height: 8)
.animation(.easeInOut(duration: 0.3), value: progress.progress)
}
}
}
private func testDetails(_ progress: BenchmarkProgress) -> some View {
VStack(alignment: .leading, spacing: 12) {
// Test iteration info
HStack {
Image(systemName: "repeat")
.font(.caption)
.foregroundColor(.benchmarkAccent)
Text("Test \(progress.currentIteration) of \(progress.totalIterations)")
.font(.subheadline)
.fontWeight(.medium)
.foregroundColor(.primary)
Spacer()
Text("PP: \(progress.nPrompt) • TG: \(progress.nGenerate)")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(Color.benchmarkAccent.opacity(0.1))
)
}
// Real-time performance metrics
if progress.runTimeSeconds > 0 {
VStack(spacing: 12) {
// Timing metrics
HStack(spacing: 12) {
MetricCard(title: "Runtime", value: String(format: "%.3fs", progress.runTimeSeconds), icon: "clock")
MetricCard(title: "Prefill", value: String(format: "%.3fs", progress.prefillTimeSeconds), icon: "arrow.up.circle")
MetricCard(title: "Decode", value: String(format: "%.3fs", progress.decodeTimeSeconds), icon: "arrow.down.circle")
}
// Speed metrics
HStack(spacing: 12) {
MetricCard(title: "Prefill Speed", value: String(format: "%.2f t/s", progress.prefillSpeed), icon: "speedometer")
MetricCard(title: "Decode Speed", value: String(format: "%.2f t/s", progress.decodeSpeed), icon: "gauge")
Spacer()
}
}
}
}
}
private var fallbackProgress: some View {
VStack(alignment: .leading, spacing: 8) {
Text("Progress")
.font(.headline)
ProgressView()
.progressViewStyle(LinearProgressViewStyle())
}
}
}
#Preview {
ProgressCard(
progress: BenchmarkProgress(
progress: 65,
statusMessage: "Running benchmark...",
progressType: .runningTest,
currentIteration: 3,
totalIterations: 5,
nPrompt: 128,
nGenerate: 256,
runTimeSeconds: 2.456,
prefillTimeSeconds: 0.123,
decodeTimeSeconds: 2.333,
prefillSpeed: 1024.5,
decodeSpeed: 109.8
)
)
.padding()
}

View File

@ -0,0 +1,311 @@
//
// ResultsCard.swift
// MNNLLMiOS
//
// Created on 2025/7/21.
//
import SwiftUI
/**
* Reusable results display card component for benchmark interface.
* Shows comprehensive benchmark results with performance metrics and statistics.
*/
struct ResultsCard: View {
let results: BenchmarkResults
var body: some View {
VStack(alignment: .leading, spacing: 20) {
resultsHeader
infoHeader
performanceMetrics
detailedStats
}
.padding(20)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(Color.benchmarkCardBg)
.overlay(
RoundedRectangle(cornerRadius: 16)
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
)
)
}
// MARK: - Private Views
private var infoHeader: some View {
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
return VStack(alignment: .leading, spacing: 8) {
Text(results.modelDisplayName)
.font(.headline)
Text(BenchmarkResultsHelper.shared.getDeviceInfo())
.font(.subheadline)
.foregroundColor(.secondary)
Text("Benchmark Config")
.font(.headline)
Text(statistics.configText)
.font(.subheadline)
.lineLimit(nil)
.fixedSize(horizontal: false, vertical: true)
.foregroundColor(.secondary)
}
}
private var resultsHeader: some View {
HStack {
HStack(spacing: 12) {
ZStack {
Circle()
.fill(
LinearGradient(
colors: [Color.benchmarkSuccess.opacity(0.2), Color.benchmarkSuccess.opacity(0.1)],
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.frame(width: 40, height: 40)
Image(systemName: "chart.bar.fill")
.font(.system(size: 18, weight: .semibold))
.foregroundColor(.benchmarkSuccess)
}
VStack(alignment: .leading, spacing: 2) {
Text("Benchmark Results")
.font(.title3)
.fontWeight(.semibold)
.foregroundColor(.primary)
Text("Performance analysis complete")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
Spacer()
Button(action: {
shareResults()
}) {
VStack(alignment: .center, spacing: 2) {
Image(systemName: "square.and.arrow.up")
.font(.title2)
.foregroundColor(.benchmarkSuccess)
Text("Share")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
.buttonStyle(PlainButtonStyle())
}
}
private var performanceMetrics: some View {
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
return VStack(spacing: 16) {
HStack(spacing: 12) {
if let prefillStats = statistics.prefillStats {
PerformanceMetricView(
icon: "speedometer",
title: "Prefill Speed",
value: BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(prefillStats),
subtitle: "Tokens per second",
color: .benchmarkGradientStart
)
} else {
PerformanceMetricView(
icon: "speedometer",
title: "Prefill Speed",
value: "N/A",
subtitle: "Tokens per second",
color: .benchmarkGradientStart
)
}
if let decodeStats = statistics.decodeStats {
PerformanceMetricView(
icon: "gauge",
title: "Decode Speed",
value: BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(decodeStats),
subtitle: "Generation rate",
color: .benchmarkGradientEnd
)
} else {
PerformanceMetricView(
icon: "gauge",
title: "Decode Speed",
value: "N/A",
subtitle: "Generation rate",
color: .benchmarkGradientEnd
)
}
}
HStack(spacing: 12) {
let totalMemoryKb = BenchmarkResultsHelper.shared.getTotalSystemMemoryKb()
let memoryInfo = BenchmarkResultsHelper.shared.formatMemoryUsage(
maxMemoryKb: results.maxMemoryKb,
totalKb: totalMemoryKb
)
PerformanceMetricView(
icon: "memorychip",
title: "Memory Usage",
value: memoryInfo.valueText,
subtitle: "Peak memory",
color: .benchmarkWarning
)
PerformanceMetricView(
icon: "clock",
title: "Total Tokens",
value: "\(statistics.totalTokensProcessed)",
subtitle: "Complete duration",
color: .benchmarkSuccess
)
}
}
}
private var detailedStats: some View {
return VStack(alignment: .leading, spacing: 12) {
VStack(spacing: 8) {
HStack {
Text("Completed")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
Spacer()
Text(results.timestamp)
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
HStack {
Text("Powered By MNN")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
Spacer()
Text(verbatim: "https://github.com/alibaba/MNN")
.font(.caption)
.foregroundColor(.benchmarkSecondary)
}
}
.padding(.vertical, 8)
}
}
// MARK: - Helper Functions
/// Formats byte count into human-readable string
private func formatBytes(_ bytes: Int64) -> String {
let formatter = ByteCountFormatter()
formatter.allowedUnits = [.useKB, .useMB, .useGB]
formatter.countStyle = .file
return formatter.string(fromByteCount: bytes)
}
/// Initiates sharing of benchmark results through system share sheet
private func shareResults() {
let viewToRender = self.body.frame(width: 390) // Adjust width as needed
if let image = viewToRender.snapshot() {
presentShareSheet(activityItems: [image, formatResultsForSharing()])
} else {
presentShareSheet(activityItems: [formatResultsForSharing()])
}
}
private func presentShareSheet(activityItems: [Any]) {
let activityViewController = UIActivityViewController(activityItems: activityItems, applicationActivities: nil)
if let windowScene = UIApplication.shared.connectedScenes.first as? UIWindowScene,
let window = windowScene.windows.first,
let rootViewController = window.rootViewController {
if let popover = activityViewController.popoverPresentationController {
popover.sourceView = window
popover.sourceRect = CGRect(x: window.bounds.midX, y: window.bounds.midY, width: 0, height: 0)
popover.permittedArrowDirections = []
}
rootViewController.present(activityViewController, animated: true)
}
}
/// Formats benchmark results into shareable text format with performance metrics and hashtags
private func formatResultsForSharing() -> String {
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
let deviceInfo = BenchmarkResultsHelper.shared.getDeviceInfo()
var shareText = """
📱 MNN LLM Benchmark Results
🤖 Model: \(results.modelDisplayName)
📱 \(deviceInfo)
📅 Completed: \(results.timestamp)
📊 Configuration:
\(statistics.configText)
Performance Results:
"""
if let prefillStats = statistics.prefillStats {
shareText += "\n🔄 Prompt Processing: \(BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(prefillStats))"
}
if let decodeStats = statistics.decodeStats {
shareText += "\n⚡️ Token Generation: \(BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(decodeStats))"
}
let totalMemoryKb = BenchmarkResultsHelper.shared.getTotalSystemMemoryKb()
let memoryInfo = BenchmarkResultsHelper.shared.formatMemoryUsage(
maxMemoryKb: results.maxMemoryKb,
totalKb: totalMemoryKb
)
shareText += "\n💾 Peak Memory: \(memoryInfo.valueText) (\(memoryInfo.labelText))"
shareText += "\n\n📈 Summary:"
shareText += "\n• Total Tokens Processed: \(statistics.totalTokensProcessed)"
shareText += "\n• Number of Tests: \(statistics.totalTests)"
shareText += "\n\n#MNNLLMBenchmark #AIPerformance #MobileAI"
return shareText
}
}
extension View {
func snapshot() -> UIImage? {
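// Hosts the view off-screen and rasterizes it; note that intrinsicContentSize is only meaningful for views that define one.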
let controller = UIHostingController(rootView: self)
let view = controller.view
let targetSize = controller.view.intrinsicContentSize
view?.bounds = CGRect(origin: .zero, size: targetSize)
view?.backgroundColor = .clear
let renderer = UIGraphicsImageRenderer(size: targetSize)
return renderer.image { _ in
view?.drawHierarchy(in: controller.view.bounds, afterScreenUpdates: true)
}
}
}
#Preview {
ResultsCard(
results: BenchmarkResults(
modelDisplayName: "Qwen2.5-1.5B-Instruct",
maxMemoryKb: 1200000, // 1.2 GB in KB
testResults: [],
timestamp: "2025-01-21 14:30:25"
)
)
.padding()
}

View File

@ -0,0 +1,64 @@
//
// StatusCard.swift
// MNNLLMiOS
//
// Created on 2025/7/21.
//
import SwiftUI
/**
* Reusable status display card component for benchmark interface.
* Shows status messages and updates to provide user feedback.
*/
struct StatusCard: View {
let statusMessage: String
var body: some View {
HStack(spacing: 16) {
ZStack {
Circle()
.fill(
LinearGradient(
colors: [Color.benchmarkWarning.opacity(0.2), Color.benchmarkWarning.opacity(0.1)],
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.frame(width: 40, height: 40)
Image(systemName: "info.circle")
.font(.system(size: 18, weight: .semibold))
.foregroundColor(.benchmarkWarning)
}
VStack(alignment: .leading, spacing: 4) {
Text("Status Update")
.font(.subheadline)
.fontWeight(.semibold)
.foregroundColor(.primary)
Text(statusMessage)
.font(.subheadline)
.foregroundColor(.benchmarkSecondary)
.fixedSize(horizontal: false, vertical: true)
}
Spacer()
}
.padding(20)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(Color.benchmarkCardBg)
.overlay(
RoundedRectangle(cornerRadius: 16)
.stroke(Color.benchmarkWarning.opacity(0.3), lineWidth: 1)
)
)
}
}
#Preview {
StatusCard(statusMessage: "Initializing benchmark test environment...")
.padding()
}

View File

@ -0,0 +1,78 @@
//
// BenchmarkView.swift
// MNNLLMiOS
//
// Created on 2025/7/21.
//
import SwiftUI
/**
* Main benchmark view that provides interface for running performance tests on ML models.
* Features include model selection, progress tracking, and results visualization.
*/
struct BenchmarkView: View {
@StateObject private var viewModel = BenchmarkViewModel()
@State private var showStopConfirmation = false
var body: some View {
ZStack {
ScrollView {
VStack(spacing: 24) {
// Model Selection Section
ModelSelectionCard(
viewModel: viewModel,
showStopConfirmation: $showStopConfirmation
)
// Progress Section
if viewModel.showProgressBar {
ProgressCard(progress: viewModel.currentProgress)
.transition(.asymmetric(
insertion: .scale.combined(with: .opacity),
removal: .opacity
))
}
// Status Section
if !viewModel.statusMessage.isEmpty {
StatusCard(statusMessage: viewModel.statusMessage)
.transition(.slide)
}
// Results Section
if viewModel.showResults, let results = viewModel.benchmarkResults {
ResultsCard(results: results)
.transition(.asymmetric(
insertion: .move(edge: .bottom).combined(with: .opacity),
removal: .opacity
))
}
Spacer(minLength: 20)
}
.padding(.horizontal, 20)
.padding(.vertical, 16)
}
}
.alert("Stop Benchmark", isPresented: $showStopConfirmation) {
Button("Yes", role: .destructive) {
viewModel.onStopBenchmarkTapped()
}
Button("No", role: .cancel) { }
} message: {
Text("Are you sure you want to stop the benchmark test?")
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK") { }
} message: {
Text(viewModel.errorMessage)
}
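// Dismiss a stale stop-confirmation dialog once the run state flips underneath it.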
.onReceive(viewModel.$isRunning) { isRunning in
if isRunning && viewModel.startButtonText.contains("Stop") {
showStopConfirmation = false
}
}
}
}

View File

@ -0,0 +1,44 @@
//
// CommonToolbarView.swift
// MNNLLMiOS
//
// Created on 2025/07/18.
//
import SwiftUI
struct CommonToolbarView: ToolbarContent {
@Binding var showHistory: Bool
@Binding var showHistoryButton: Bool
var body: some ToolbarContent {
ToolbarItem(placement: .navigationBarLeading) {
if showHistoryButton {
Button(action: {
showHistory = true
showHistoryButton = false
}) {
Image(systemName: "sidebar.left")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 20, height: 20)
.foregroundColor(.black)
}
}
}
ToolbarItem(placement: .navigationBarTrailing) {
Button(action: {
if let url = URL(string: "https://github.com/alibaba/MNN") {
UIApplication.shared.open(url)
}
}) {
Image(systemName: "star")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 20, height: 20)
.foregroundColor(.black)
}
}
}
}

View File

@ -0,0 +1,37 @@
//
// LocalModelListView.swift
// MNNLLMiOS
//
// Created on 2025/06/20.
//
import SwiftUI
struct LocalModelListView: View {
@ObservedObject var viewModel: ModelListViewModel
var body: some View {
List {
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
Button(action: {
viewModel.selectModel(model)
}) {
LocalModelRowView(model: model)
}
.listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
SwipeActionsView(model: model, viewModel: viewModel)
}
}
}
.listStyle(.plain)
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK", role: .cancel) {}
} message: {
Text(viewModel.errorMessage)
}
}
}

View File

@ -0,0 +1,64 @@
//
// LocalModelRowView.swift
// MNNLLMiOS
//
// Created on 2025/6/26.
//
import SwiftUI
struct LocalModelRowView: View {
let model: ModelInfo
private var localizedTags: [String] {
model.localizedTags
}
private var formattedSize: String {
model.formattedSize
}
var body: some View {
HStack(alignment: .center) {
ModelIconView(modelId: model.id)
.frame(width: 40, height: 40)
VStack(alignment: .leading, spacing: 8) {
Text(model.modelName)
.font(.headline)
.fontWeight(.semibold)
.lineLimit(1)
if !localizedTags.isEmpty {
TagsView(tags: localizedTags)
}
HStack {
HStack(alignment: .center, spacing: 2) {
Image(systemName: "folder")
.font(.caption)
.fontWeight(.medium)
.foregroundColor(.gray)
.frame(width: 20, height: 20)
Text(formattedSize)
.font(.caption)
.fontWeight(.medium)
.foregroundColor(.gray)
}
Spacer()
if let lastUsedAt = model.lastUsedAt {
Text("\(lastUsedAt.formatAgo())")
.font(.caption)
.fontWeight(.medium)
.foregroundColor(.gray)
}
}
}
}
}
}

View File

@ -0,0 +1,230 @@
//
// MainTabView.swift
// MNNLLMiOS
//
// Created on 2025/06/20.
//
import SwiftUI
// MainTabView is the primary view of the app, containing the tab bar and navigation for main sections.
struct MainTabView: View {
// MARK: - State Properties
@State private var showHistory = false
@State private var selectedHistory: ChatHistory? = nil
@State private var histories: [ChatHistory] = ChatHistoryManager.shared.getAllHistory()
@State private var showHistoryButton = true
@State private var showSettings = false
@State private var showWebView = false
@State private var webViewURL: URL?
@State private var navigateToSettings = false
@StateObject private var modelListViewModel = ModelListViewModel()
@State private var selectedTab: Int = 0
private var titles: [String] {
[
NSLocalizedString("Local Model", comment: "本地模型标签"),
NSLocalizedString("Model Market", comment: "模型市场标签"),
NSLocalizedString("Benchmark", comment: "基准测试标签")
]
}
// MARK: - Body
var body: some View {
ZStack {
// Main TabView for navigation between Local Model, Model Market, and Benchmark
TabView(selection: $selectedTab) {
NavigationView {
LocalModelListView(viewModel: modelListViewModel)
.navigationTitle(titles[0])
.navigationBarTitleDisplayMode(.inline)
.navigationBarHidden(false)
.onAppear {
setupNavigationBarAppearance()
}
.toolbar {
CommonToolbarView(
showHistory: $showHistory,
showHistoryButton: $showHistoryButton,
)
}
.background(
ZStack {
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
}
)
// Hide TabBar when entering chat or settings view
.toolbar((chatIsActiveBinding.wrappedValue || navigateToSettings) ? .hidden : .visible, for: .tabBar)
}
.tabItem {
Image(systemName: "house.fill")
Text(titles[0])
}
.tag(0)
NavigationView {
ModelListView(viewModel: modelListViewModel)
.navigationTitle(titles[1])
.navigationBarTitleDisplayMode(.inline)
.navigationBarHidden(false)
.onAppear {
setupNavigationBarAppearance()
}
.toolbar {
CommonToolbarView(
showHistory: $showHistory,
showHistoryButton: $showHistoryButton
)
}
.background(
ZStack {
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
}
)
}
.tabItem {
Image(systemName: "doc.text.fill")
Text(titles[1])
}
.tag(1)
NavigationView {
BenchmarkView()
.navigationTitle(titles[2])
.navigationBarTitleDisplayMode(.inline)
.navigationBarHidden(false)
.onAppear {
setupNavigationBarAppearance()
}
.toolbar {
CommonToolbarView(
showHistory: $showHistory,
showHistoryButton: $showHistoryButton
)
}
.background(
ZStack {
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
}
)
}
.tabItem {
Image(systemName: "clock.fill")
Text(titles[2])
}
.tag(2)
}
.onAppear {
setupTabBarAppearance()
}
.tint(.black)
// Overlay for dimming the background when history is shown
if showHistory {
Color.black.opacity(0.5)
.edgesIgnoringSafeArea(.all)
.onTapGesture {
withAnimation(.easeInOut(duration: 0.2)) {
showHistory = false
}
}
}
// Side menu for displaying chat history
SideMenuView(isOpen: $showHistory,
selectedHistory: $selectedHistory,
histories: $histories,
navigateToMainSettings: $navigateToSettings)
.edgesIgnoringSafeArea(.all)
}
.onChange(of: showHistory) { oldValue, newValue in
if !newValue {
DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
withAnimation {
showHistoryButton = true
}
}
}
}
.sheet(isPresented: $showWebView) {
if let url = webViewURL {
WebView(url: url)
}
}
}
// MARK: - View Builders
/// Destination view for chat, either from a new model or a history item.
@ViewBuilder
private var chatDestination: some View {
if let model = modelListViewModel.selectedModel {
LLMChatView(modelInfo: model)
.navigationBarHidden(false)
.navigationBarTitleDisplayMode(.inline)
} else if let history = selectedHistory {
let modelInfo = ModelInfo(modelId: history.modelId, isDownloaded: true)
LLMChatView(modelInfo: modelInfo, history: history)
.navigationBarHidden(false)
.navigationBarTitleDisplayMode(.inline)
} else {
EmptyView()
}
}
// MARK: - Bindings
/// Binding to control the activation of the chat view.
private var chatIsActiveBinding: Binding<Bool> {
Binding<Bool>(
get: {
return modelListViewModel.selectedModel != nil || selectedHistory != nil
},
set: { isActive in
if !isActive {
// Record usage when returning from chat
if let model = modelListViewModel.selectedModel {
modelListViewModel.recordModelUsage(modelName: model.modelName)
}
// Clear selections
modelListViewModel.selectedModel = nil
selectedHistory = nil
}
}
)
}
// MARK: - Private Methods
/// Configures the appearance of the navigation bar.
private func setupNavigationBarAppearance() {
let appearance = UINavigationBarAppearance()
appearance.configureWithOpaqueBackground()
appearance.backgroundColor = .white
appearance.shadowColor = .clear
UINavigationBar.appearance().standardAppearance = appearance
UINavigationBar.appearance().compactAppearance = appearance
UINavigationBar.appearance().scrollEdgeAppearance = appearance
}
/// Configures the appearance of the tab bar.
private func setupTabBarAppearance() {
let appearance = UITabBarAppearance()
appearance.configureWithOpaqueBackground()
let selectedColor = UIColor(Color.primaryPurple)
appearance.stackedLayoutAppearance.selected.iconColor = selectedColor
appearance.stackedLayoutAppearance.selected.titleTextAttributes = [.foregroundColor: selectedColor]
UITabBar.appearance().standardAppearance = appearance
UITabBar.appearance().scrollEdgeAppearance = appearance
}
}

View File

@ -0,0 +1,203 @@
//
// TBModelInfo.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import Hub
import Foundation
struct ModelInfo: Codable {
// MARK: - Properties
let modelName: String
let tags: [String]
let categories: [String]?
let size_gb: Double?
let vendor: String?
let sources: [String: String]?
let tagTranslations: [String: [String]]?
// Runtime properties
var isDownloaded: Bool = false
var lastUsedAt: Date?
var cachedSize: Int64? = nil
// MARK: - Initialization
init(modelName: String = "",
tags: [String] = [],
categories: [String]? = nil,
size_gb: Double? = nil,
vendor: String? = nil,
sources: [String: String]? = nil,
tagTranslations: [String: [String]]? = nil,
isDownloaded: Bool = false,
lastUsedAt: Date? = nil,
cachedSize: Int64? = nil) {
self.modelName = modelName
self.tags = tags
self.categories = categories
self.size_gb = size_gb
self.vendor = vendor
self.sources = sources
self.tagTranslations = tagTranslations
self.isDownloaded = isDownloaded
self.lastUsedAt = lastUsedAt
self.cachedSize = cachedSize
}
init(modelId: String, isDownloaded: Bool = true) {
let modelName = modelId.components(separatedBy: "/").last ?? modelId
self.init(
modelName: modelName,
tags: [],
sources: ["huggingface": modelId],
isDownloaded: isDownloaded
)
}
// MARK: - Model Identity & Localization
var id: String {
guard let sources = sources else {
return "taobao-mnn/\(modelName)"
}
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
}
var localizedTags: [String] {
let currentLanguage = LanguageManager.shared.currentLanguage
let isChineseLanguage = currentLanguage == "简体中文"
if isChineseLanguage, let translations = tagTranslations {
let languageCode = "zh-Hans"
return translations[languageCode] ?? tags
} else {
return tags
}
}
// MARK: - File System & Path Management
var localPath: String {
let modelScopeId = "taobao-mnn/\(modelName)"
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
}
// MARK: - Size Calculation & Formatting
var formattedSize: String {
if let cached = cachedSize {
return FileOperationManager.shared.formatBytes(cached)
} else if isDownloaded {
return FileOperationManager.shared.formatLocalDirectorySize(at: localPath)
} else if let sizeGb = size_gb {
return String(format: "%.1f GB", sizeGb)
} else {
return "None"
}
}
/// Calculates and caches the local directory size
/// - Returns: The formatted size string and updates cachedSize property
mutating func calculateAndCacheSize() -> String {
if let cached = cachedSize {
return FileOperationManager.shared.formatBytes(cached)
}
if isDownloaded {
do {
let sizeInBytes = try FileOperationManager.shared.calculateDirectorySize(at: localPath)
self.cachedSize = sizeInBytes
return FileOperationManager.shared.formatBytes(sizeInBytes)
} catch {
print("Error calculating directory size: \(error)")
return "Unknown"
}
} else if let sizeGb = size_gb {
return String(format: "%.1f GB", sizeGb)
} else {
return "None"
}
}
// MARK: - Remote Size Calculation
func fetchRemoteSize() async -> Int64? {
let modelScopeId = "taobao-mnn/\(modelName)"
do {
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
return totalSize
} catch {
print("Error fetching remote size for \(id): \(error)")
return nil
}
}
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
let url = try buildURL(
repoPath: repoPath,
path: "/repo/files",
queryItems: [
URLQueryItem(name: "Root", value: root),
URLQueryItem(name: "Revision", value: revision)
]
)
let (data, response) = try await URLSession.shared.data(from: url)
try validateResponse(response)
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
return modelResponse.data.files
}
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
var totalSize: Int64 = 0
for file in files {
if file.type == "tree" {
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
} else if file.type == "blob" {
totalSize += Int64(file.size)
}
}
return totalSize
}
// MARK: - Network Utilities
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
var components = URLComponents()
components.scheme = "https"
components.host = "modelscope.cn"
components.path = "/api/v1/models/\(repoPath)\(path)"
components.queryItems = queryItems
guard let url = components.url else {
throw ModelScopeError.invalidURL
}
return url
}
private func validateResponse(_ response: URLResponse) throws {
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw ModelScopeError.invalidResponse
}
}
// MARK: - Codable
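// Runtime-only state (isDownloaded, lastUsedAt) is intentionally excluded so it never round-trips through JSON; cachedSize is persisted.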
private enum CodingKeys: String, CodingKey {
case modelName, tags, categories, size_gb, vendor, sources, tagTranslations, cachedSize
}
}

View File

@ -0,0 +1,361 @@
//
// ModelListViewModel.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import Foundation
import SwiftUI
class ModelListViewModel: ObservableObject {
// MARK: - Published Properties
@Published var models: [ModelInfo] = []
@Published var searchText = ""
@Published var quickFilterTags: [String] = []
@Published var selectedModel: ModelInfo?
@Published var showError = false
@Published var errorMessage = ""
// Download state
@Published private(set) var downloadProgress: [String: Double] = [:]
@Published private(set) var currentlyDownloading: String?
// MARK: - Private Properties
private let modelClient = ModelClient()
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
// MARK: - Model Data Access
public var pinnedModelIds: [String] {
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
}
var allTags: [String] {
Array(Set(models.flatMap { $0.tags }))
}
var allCategories: [String] {
Array(Set(models.compactMap { $0.categories }.flatMap { $0 }))
}
var allVendors: [String] {
Array(Set(models.compactMap { $0.vendor }))
}
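// Search matches id, model name, or localized tags; downloaded models always sort ahead of undownloaded ones.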
var filteredModels: [ModelInfo] {
let filtered = searchText.isEmpty ? models : models.filter { model in
model.id.localizedCaseInsensitiveContains(searchText) ||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
}
let downloaded = filtered.filter { $0.isDownloaded }
let notDownloaded = filtered.filter { !$0.isDownloaded }
return downloaded + notDownloaded
}
// MARK: - Initialization
init() {
Task { @MainActor in
await fetchModels()
}
}
// MARK: - Model Data Management
@MainActor
func fetchModels() async {
do {
let info = try await modelClient.getModelInfo()
self.quickFilterTags = info.quickFilterTags ?? []
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
var fetchedModels = info.models
filterDiffusionModels(fetchedModels: &fetchedModels)
loadCachedSizes(for: &fetchedModels)
sortModels(fetchedModels: &fetchedModels)
self.models = fetchedModels
// Asynchronously fetch size info for both downloaded and undownloaded models
Task {
await fetchModelSizes(for: fetchedModels)
}
} catch {
showError = true
errorMessage = "Error: \(error.localizedDescription)"
}
}
private func loadCachedSizes(for models: inout [ModelInfo]) {
for i in 0..<models.count {
if let cachedSize = ModelStorageManager.shared.getCachedSize(for: models[i].modelName) {
models[i].cachedSize = cachedSize
}
}
}
private func fetchModelSizes(for models: [ModelInfo]) async {
await withTaskGroup(of: Void.self) { group in
for model in models {
// Handle undownloaded models - fetch remote size
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
group.addTask {
if let size = await model.fetchRemoteSize() {
await MainActor.run {
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[modelIndex].cachedSize = size
ModelStorageManager.shared.setCachedSize(size, for: model.modelName)
}
}
}
}
}
// Handle downloaded models - calculate and cache local directory size
if model.isDownloaded && model.cachedSize == nil {
group.addTask {
do {
let localSize = try FileOperationManager.shared.calculateDirectorySize(at: model.localPath)
await MainActor.run {
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[modelIndex].cachedSize = localSize
ModelStorageManager.shared.setCachedSize(localSize, for: model.modelName)
}
}
} catch {
print("Error calculating local directory size for \(model.modelName): \(error)")
}
}
}
}
}
}
private func filterDiffusionModels(fetchedModels: inout [ModelInfo]) {
let hasDiffusionModels = fetchedModels.contains {
$0.modelName.lowercased().contains("diffusion")
}
if hasDiffusionModels {
fetchedModels = fetchedModels.filter { model in
let name = model.modelName.lowercased()
let tags = model.tags.map { $0.lowercased() }
// Only show GPU diffusion models
if name.contains("diffusion") {
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
}
return true
}
}
for i in 0..<fetchedModels.count {
let model = fetchedModels[i]
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
}
}
private func sortModels(fetchedModels: inout [ModelInfo]) {
let pinned = pinnedModelIds
fetchedModels.sort { (model1, model2) -> Bool in
let isPinned1 = pinned.contains(model1.id)
let isPinned2 = pinned.contains(model2.id)
let isDownloading1 = currentlyDownloading == model1.id
let isDownloading2 = currentlyDownloading == model2.id
// 1. Currently downloading models have highest priority
if isDownloading1 != isDownloading2 {
return isDownloading1
}
// 2. Pinned models have second priority
if isPinned1 != isPinned2 {
return isPinned1
}
// 3. If both are pinned, sort by pin time
if isPinned1 && isPinned2 {
let index1 = pinned.firstIndex(of: model1.id)!
let index2 = pinned.firstIndex(of: model2.id)!
return index1 > index2 // Pinned later comes first
}
// 4. Non-pinned models sorted by download status
if model1.isDownloaded != model2.isDownloaded {
return model1.isDownloaded
}
// 5. If both downloaded, sort by last used time
if model1.isDownloaded {
let date1 = model1.lastUsedAt ?? .distantPast
let date2 = model2.lastUsedAt ?? .distantPast
return date1 > date2
}
return false // Keep original order for not-downloaded
}
}
// MARK: - Model Selection & Usage
@MainActor
func selectModel(_ model: ModelInfo) {
if model.isDownloaded {
selectedModel = model
} else {
Task {
await downloadModel(model)
}
}
}
func recordModelUsage(modelName: String) {
ModelStorageManager.shared.updateLastUsed(for: modelName)
Task { @MainActor in
if let index = self.models.firstIndex(where: { $0.modelName == modelName }) {
self.models[index].lastUsedAt = Date()
self.sortModels(fetchedModels: &self.models)
}
}
}
// MARK: - Download Management
func downloadModel(_ model: ModelInfo) async {
await MainActor.run {
guard currentlyDownloading == nil else { return }
currentlyDownloading = model.id
downloadProgress[model.id] = 0
}
do {
try await modelClient.downloadModel(model: model) { progress in
Task { @MainActor in
self.downloadProgress[model.id] = progress
}
}
await MainActor.run {
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[index].isDownloaded = true
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
}
}
// Calculate and cache size for newly downloaded model
do {
let localSize = try FileOperationManager.shared.calculateDirectorySize(at: model.localPath)
await MainActor.run {
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[index].cachedSize = localSize
ModelStorageManager.shared.setCachedSize(localSize, for: model.modelName)
}
}
} catch {
print("Error calculating size for newly downloaded model \(model.modelName): \(error)")
}
} catch {
await MainActor.run {
if case ModelScopeError.downloadCancelled = error {
print("Download was cancelled")
} else {
self.showError = true
self.errorMessage = "Failed to download model: \(error.localizedDescription)"
}
}
}
await MainActor.run {
self.currentlyDownloading = nil
self.downloadProgress.removeValue(forKey: model.id)
}
}
func cancelDownload() async {
let modelId = await MainActor.run { currentlyDownloading }
if let modelId = modelId {
await modelClient.cancelDownload()
await MainActor.run {
self.downloadProgress.removeValue(forKey: modelId)
self.currentlyDownloading = nil
}
print("Download cancelled for model: \(modelId)")
}
}
// MARK: - Pin Management
@MainActor
func pinModel(_ model: ModelInfo) {
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
let pinned = models.remove(at: index)
models.insert(pinned, at: 0)
var pinnedIds = pinnedModelIds
if let existingIndex = pinnedIds.firstIndex(of: model.id) {
pinnedIds.remove(at: existingIndex)
}
pinnedIds.insert(model.id, at: 0)
pinnedModelIds = pinnedIds
}
@MainActor
func unpinModel(_ model: ModelInfo) {
var pinnedIds = pinnedModelIds
if let index = pinnedIds.firstIndex(of: model.id) {
pinnedIds.remove(at: index)
pinnedModelIds = pinnedIds
// Re-sort models after unpinning
sortModels(fetchedModels: &models)
}
}
// MARK: - Model Deletion
func deleteModel(_ model: ModelInfo) async {
guard model.isDownloaded else { return }
do {
// Delete local files
let fileManager = FileManager.default
let modelPath = model.localPath
if fileManager.fileExists(atPath: modelPath) {
try fileManager.removeItem(atPath: modelPath)
}
// Update model state
await MainActor.run {
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[index].isDownloaded = false
self.models[index].cachedSize = nil
ModelStorageManager.shared.markModelAsNotDownloaded(model.modelName)
}
// Re-sort models after deletion
self.sortModels(fetchedModels: &self.models)
}
} catch {
await MainActor.run {
self.showError = true
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
}
}
}
}

View File

@ -0,0 +1,24 @@
//
// TBDataResponse.swift
// MNNLLMiOS
//
// Created on 2025/7/9.
//
import Foundation
struct TBDataResponse: Codable {
let tagTranslations: [String: String]
let quickFilterTags: [String]?
let models: [ModelInfo]
let metadata: Metadata?
struct Metadata: Codable {
let version: String
let lastUpdated: String
let schemaVersion: String
let totalModels: Int
let supportedPlatforms: [String]
let minAppVersion: String
}
}

View File

@ -0,0 +1,196 @@
//
// ModelClient.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import Hub
import Foundation
class ModelClient {
private let maxRetries = 5
private let baseMirrorURL = "https://hf-mirror.com"
private let baseURL = "https://huggingface.co"
private let AliCDNURL = "https://meta.alicdn.com/data/mnn/apis/model_market.json"
// Debug flag to use local mock data instead of network API
private let useLocalMockData = false
private var currentDownloadManager: ModelScopeDownloadManager?
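// Resolve the endpoint once: the official HuggingFace host for the HuggingFace source, the mirror for everything else.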
private lazy var baseURLString: String = {
switch ModelSourceManager.shared.selectedSource {
case .huggingFace:
return baseURL
default:
return baseMirrorURL
}
}()
init() {}
func getModelInfo() async throws -> TBDataResponse {
if useLocalMockData {
// Debug mode: use local mock data
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
return mockResponse
} else {
// Production mode: fetch from network API
return try await fetchDataFromAliAPI()
}
}
/**
* Fetches data from the network API with fallback to local mock data
*
* @throws NetworkError if both network request and local fallback fail
*/
private func fetchDataFromAliAPI() async throws -> TBDataResponse {
do {
guard let url = URL(string: AliCDNURL) else {
throw NetworkError.invalidData
}
let (data, response) = try await URLSession.shared.data(from: url)
guard let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 200 else {
throw NetworkError.invalidResponse
}
let apiResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
return apiResponse
} catch {
print("Network request failed: \(error). Falling back to local mock data.")
// Fallback to local mock data if network request fails
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
return mockResponse
}
}
/**
* Downloads a model from the selected source with progress tracking
*
* @param model The ModelInfo object containing model details
* @param progress Progress callback that receives download progress (0.0 to 1.0)
* @throws Various network or file system errors
*/
func downloadModel(model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
switch ModelSourceManager.shared.selectedSource {
case .modelScope, .modeler:
try await downloadFromModelScope(model, progress: progress)
case .huggingFace:
try await downloadFromHuggingFace(model, progress: progress)
}
}
/**
* Cancels the current download operation
*/
func cancelDownload() async {
if let manager = currentDownloadManager {
await manager.cancelDownload()
currentDownloadManager = nil
print("Download cancelled")
}
}
/**
* Downloads model from ModelScope platform
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromModelScope(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
let modelScopeId = model.id
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30
config.timeoutIntervalForResource = 300
let manager = ModelScopeDownloadManager(repoPath: modelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
currentDownloadManager = manager
try await manager.downloadModel(to: "huggingface/models/taobao-mnn", modelId: modelScopeId, modelName: model.modelName) { fileProgress in
Task { @MainActor in
progress(fileProgress)
}
}
currentDownloadManager = nil
}
/**
* Downloads model from HuggingFace platform with optimized progress updates
*
* This method implements throttling to prevent UI stuttering by limiting
* progress update frequency and filtering out minor progress changes.
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromHuggingFace(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
let repo = Hub.Repo(id: model.id)
let modelFiles = ["*.*"]
let hubApi = HubApi(endpoint: baseURL)
// Progress throttling mechanism to prevent UI stuttering
var lastUpdateTime = Date()
var lastProgress: Double = 0.0
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
try await hubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
let currentProgress = fileProgress.fractionCompleted
let currentTime = Date()
// Check if progress should be updated
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
let progressDiff = abs(currentProgress - lastProgress)
// Update progress if any of these conditions are met:
// 1. Time interval exceeds threshold
// 2. Progress change exceeds threshold
// 3. Progress reaches 100% (download complete)
// 4. Progress is 0% (download start)
if timeDiff >= progressUpdateInterval ||
progressDiff >= progressThreshold ||
currentProgress >= 1.0 ||
currentProgress == 0.0 {
lastUpdateTime = currentTime
lastProgress = currentProgress
// Ensure progress updates are executed on the main thread
Task { @MainActor in
progress(currentProgress)
}
}
}
}
}
enum NetworkError: Error {
case invalidResponse
case invalidData
case downloadFailed
case unknown
}

View File

@ -1,5 +1,5 @@
//
// ModelClient.swift
// ModelScopeDownloadManager.swift
// MNNLLMiOS
//
// Created on 2025/2/20.
@ -31,6 +31,11 @@ public actor ModelScopeDownloadManager: Sendable {
private var downloadedSize: Int64 = 0
private var lastUpdatedBytes: Int64 = 0
// Download cancellation related properties
private var isCancelled: Bool = false
private var currentDownloadTask: Task<Void, Error>?
private var currentFileHandle: FileHandle?
// MARK: - Initialization
/// Creates a new ModelScope download manager
@ -83,6 +88,9 @@ public actor ModelScopeDownloadManager: Sendable {
modelName: String,
progress: ((Double) -> Void)? = nil
) async throws {
isCancelled = false
ModelScopeLogger.info("Starting download for modelId: \(modelId)")
let destination = try resolveDestinationPath(base: destinationFolder, modelId: modelName)
@ -100,7 +108,28 @@ public actor ModelScopeDownloadManager: Sendable {
)
}
// MARK: - Private Methods
/// Cancel download
/// Preserve downloaded temporary files to support resume functionality
public func cancelDownload() async {
isCancelled = true
currentDownloadTask?.cancel()
currentDownloadTask = nil
await closeFileHandle()
session.invalidateAndCancel()
ModelScopeLogger.info("Download cancelled, temporary files preserved for resume")
}
// MARK: - Private Methods - Progress Management
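/// Hops to the main actor before invoking the caller's progress closure so UI bindings update safely.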
private func updateProgress(_ progress: Double, callback: @escaping (Double) -> Void) {
Task { @MainActor in
callback(progress)
}
}
private func fetchFileList(
root: String,
@ -131,6 +160,8 @@ public actor ModelScopeDownloadManager: Sendable {
var lastError: Error?
for attempt in 1...maxRetries {
if isCancelled { break }
do {
print("Attempt \(attempt) of \(maxRetries) for file: \(file.name)")
try await downloadFileWithRetry(
@ -162,6 +193,11 @@ public actor ModelScopeDownloadManager: Sendable {
destinationPath: String,
onProgress: @escaping (Int64) -> Void
) async throws {
if isCancelled {
throw ModelScopeError.downloadCancelled
}
let session = self.session
ModelScopeLogger.info("Starting download for file: \(file.name)")
@ -209,28 +245,45 @@ public actor ModelScopeDownloadManager: Sendable {
ModelScopeLogger.debug("Requesting URL: \(url)")
return try await withCheckedThrowingContinuation { continuation in
Task {
currentDownloadTask = Task {
do {
let (asyncBytes, response) = try await session.bytes(for: request)
ModelScopeLogger.debug("Response status code: \((response as? HTTPURLResponse)?.statusCode ?? -1)")
try validateResponse(response)
let fileHandle = try FileHandle(forWritingTo: tempURL)
self.currentFileHandle = fileHandle
if resumeOffset > 0 {
try fileHandle.seek(toOffset: UInt64(resumeOffset))
}
var downloadedBytes: Int64 = resumeOffset
var bytesCount = 0
for try await byte in asyncBytes {
// Frequently check cancellation status
if isCancelled {
try fileHandle.close()
self.currentFileHandle = nil
// Don't delete temp files when cancelled, preserve resume functionality
continuation.resume(throwing: ModelScopeError.downloadCancelled)
return
}
try fileHandle.write(contentsOf: [byte])
downloadedBytes += 1
if downloadedBytes % 1024 == 0 {
bytesCount += 1
// bytesCount ticks once per 1 KB written; flush progress every 64 * 1024 * 5 ticks to keep callbacks cheap
if bytesCount >= 64 * 1024 * 5 {
onProgress(downloadedBytes)
bytesCount = 0
}
}
try fileHandle.close()
self.currentFileHandle = nil
let finalSize = try FileManager.default.attributesOfItem(atPath: tempURL.path)[.size] as? Int64 ?? 0
guard finalSize == file.size else {
@ -250,8 +303,16 @@ public actor ModelScopeDownloadManager: Sendable {
onProgress(downloadedBytes)
continuation.resume()
} catch {
ModelScopeLogger.error("Download failed: \(error.localizedDescription)")
storage.clearFileStatus(at: destination.path)
// Clean up file handle when handling errors
if let handle = self.currentFileHandle {
try? handle.close()
self.setCurrentFileHandle(nil)
}
if !isCancelled {
ModelScopeLogger.error("Download failed: \(error.localizedDescription)")
storage.clearFileStatus(at: destination.path)
}
continuation.resume(throwing: error)
}
}
@ -266,6 +327,10 @@ public actor ModelScopeDownloadManager: Sendable {
) async throws {
ModelScopeLogger.info("Starting download with \(files.count) files")
if isCancelled {
throw ModelScopeError.downloadCancelled
}
func calculateTotalSize(files: [ModelFile]) async throws -> Int64 {
var size: Int64 = 0
for file in files {
@ -288,6 +353,11 @@ public actor ModelScopeDownloadManager: Sendable {
}
for file in files {
if Task.isCancelled || isCancelled {
throw ModelScopeError.downloadCancelled
}
ModelScopeLogger.debug("Processing: \(file.name), type: \(file.type)")
if file.type == "tree" {
@ -317,15 +387,7 @@ public actor ModelScopeDownloadManager: Sendable {
destinationPath: destinationPath,
onProgress: { downloadedBytes in
let currentProgress = Double(self.downloadedSize + downloadedBytes) / Double(self.totalSize)
progress(currentProgress)
// 1MB = 1,024 * 1,024
let bytesDelta = self.downloadedSize - self.lastUpdatedBytes
if bytesDelta >= 1_024 * 1_024 {
self.lastUpdatedBytes = self.downloadedSize
DispatchQueue.main.async {
progress(currentProgress)
}
}
self.updateProgress(currentProgress, callback: progress)
},
maxRetries: 500,
retryDelay: 1.0
@ -340,9 +402,42 @@ public actor ModelScopeDownloadManager: Sendable {
ModelScopeLogger.debug("File exists: \(file.name)")
}
progress(Double(downloadedSize) / Double(totalSize))
let currentProgress = Double(downloadedSize) / Double(totalSize)
updateProgress(currentProgress, callback: progress)
}
}
Task { @MainActor in
progress(1.0)
}
}
private func resetDownloadState() async {
totalFiles = 0
downloadedFiles = 0
totalSize = 0
downloadedSize = 0
lastUpdatedBytes = 0
}
private func resetCancelStatus() {
isCancelled = false
totalFiles = 0
downloadedFiles = 0
totalSize = 0
downloadedSize = 0
lastUpdatedBytes = 0
}
private func closeFileHandle() async {
do {
try currentFileHandle?.close()
currentFileHandle = nil
} catch {
print("Error closing file handle: \(error)")
}
}
private func buildURL(
@ -410,4 +505,27 @@ public actor ModelScopeDownloadManager: Sendable {
return modelScopePath.path
}
private func setCurrentFileHandle(_ handle: FileHandle?) {
currentFileHandle = handle
}
private func getTempFileSize(for file: ModelFile, destinationPath: String) -> Int64 {
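// Temp files are keyed by repo and file-path hashes, which lets a cancelled download resume from the bytes already on disk.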
let modelHash = repoPath.hash
let fileHash = file.path.hash
let tempURL = FileManager.default.temporaryDirectory
.appendingPathComponent("model_\(modelHash)_file_\(fileHash)_\(file.name.sanitizedPath).tmp")
guard fileManager.fileExists(atPath: tempURL.path) else {
return 0
}
do {
let attributes = try fileManager.attributesOfItem(atPath: tempURL.path)
return attributes[.size] as? Int64 ?? 0
} catch {
ModelScopeLogger.error("Failed to get temp file size for \(file.name): \(error)")
return 0
}
}
}

View File

@ -10,6 +10,7 @@ import Foundation
public enum ModelScopeError: Error {
case invalidURL
case invalidResponse
case downloadCancelled
case downloadFailed(Error)
case fileSystemError(Error)
case invalidData

View File

@ -0,0 +1,65 @@
//
// CustomPopupMenu.swift
// MNNLLMiOS
//
// Created on 2025/6/30.
//
import SwiftUI
struct CustomPopupMenu: View {
@Binding var isPresented: Bool
@Binding var selectedSource: ModelSource
let anchorFrame: CGRect
var body: some View {
GeometryReader { geometry in
ZStack(alignment: .top) {
Color.black.opacity(0.3)
.frame(maxWidth: .infinity)
.frame(height: UIScreen.main.bounds.height - anchorFrame.maxY)
.offset(y: anchorFrame.maxY - 10)
.onTapGesture {
isPresented = false
}
VStack(spacing: 0) {
ForEach(ModelSource.allCases) { source in
Button {
selectedSource = source
ModelSourceManager.shared.updateSelectedSource(source)
isPresented = false
} label: {
HStack {
Text(source.description)
.font(.system(size: 12, weight: .regular))
.foregroundColor(source == selectedSource ? .primaryBlue : .black)
Spacer()
if source == selectedSource {
Image(systemName: "checkmark.circle")
.foregroundColor(.primaryBlue)
}
}
.frame(maxWidth: .infinity)
.padding()
.background(.white)
}
Divider()
}
}
.background(Color.white)
.cornerRadius(8)
.shadow(color: .black.opacity(0.1), radius: 5, x: 0, y: 5)
.frame(width: geometry.size.width)
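// .position places the menu's center point; offsetting from anchorFrame.maxY keeps it visually attached to the selector.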
.position(
x: geometry.size.width / 2,
y: anchorFrame.maxY - 24
)
}
}
.transition(.opacity)
.animation(.spring(response: 0.3, dampingFraction: 0.8, blendDuration: 0), value: isPresented)
}
}

View File

@ -0,0 +1,33 @@
//
// FilterButton.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct FilterButton: View {
@Binding var showFilterMenu: Bool
@Binding var selectedTags: Set<String>
@Binding var selectedCategories: Set<String>
@Binding var selectedVendors: Set<String>
var body: some View {
Button(action: {
showFilterMenu.toggle()
}) {
Image(systemName: "line.3.horizontal.decrease.circle")
.font(.system(size: 20))
.foregroundColor(.primary)
}
.sheet(isPresented: $showFilterMenu) {
FilterMenuView(
selectedTags: $selectedTags,
selectedCategories: $selectedCategories,
selectedVendors: $selectedVendors
)
.presentationDetents([.large])
}
}
}

View File

@ -0,0 +1,93 @@
//
// FilterMenuView.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct FilterMenuView: View {
@Environment(\.dismiss) private var dismiss
@StateObject private var viewModel = ModelListViewModel()
@Binding var selectedTags: Set<String>
@Binding var selectedCategories: Set<String>
@Binding var selectedVendors: Set<String>
var body: some View {
NavigationView {
ScrollView {
VStack(alignment: .leading, spacing: 24) {
VStack(alignment: .leading, spacing: 12) {
Text("filter.byTag")
.font(.headline)
.fontWeight(.semibold)
LazyVGrid(columns: Array(repeating: GridItem(.flexible()), count: 2), spacing: 8) {
ForEach(viewModel.allTags.sorted(), id: \.self) { tag in
FilterOptionRow(
text: TagTranslationManager.shared.getLocalizedTag(tag),
isSelected: selectedTags.contains(tag)
) {
if selectedTags.contains(tag) {
selectedTags.remove(tag)
} else {
selectedTags.insert(tag)
}
}
}
}
}
Divider()
VStack(alignment: .leading, spacing: 12) {
Text("filter.byVendor")
.font(.headline)
.fontWeight(.semibold)
LazyVGrid(columns: Array(repeating: GridItem(.flexible()), count: 2), spacing: 8) {
ForEach(viewModel.allVendors.sorted(), id: \.self) { vendor in
FilterOptionRow(
text: vendor,
isSelected: selectedVendors.contains(vendor)
) {
if selectedVendors.contains(vendor) {
selectedVendors.remove(vendor)
} else {
selectedVendors.insert(vendor)
}
}
}
}
}
Spacer(minLength: 100)
}
.padding()
}
.navigationTitle("filter.title")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarLeading) {
Button("button.clear") {
selectedTags.removeAll()
selectedCategories.removeAll()
selectedVendors.removeAll()
}
}
ToolbarItem(placement: .navigationBarTrailing) {
Button("button.done") {
dismiss()
}
}
}
}
.onAppear {
Task {
await viewModel.fetchModels()
}
}
}
}

View File

@ -0,0 +1,40 @@
//
// FilterOptionRow.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct FilterOptionRow: View {
let text: String
let isSelected: Bool
let onTap: () -> Void
var body: some View {
Button(action: onTap) {
HStack {
Text(text)
.font(.system(size: 14))
.foregroundColor(.primary)
Spacer()
if isSelected {
Image(systemName: "checkmark.circle.fill")
.foregroundColor(.accentColor)
} else {
Image(systemName: "circle")
.foregroundColor(.secondary)
}
}
.padding(.horizontal, 12)
.padding(.vertical, 8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(isSelected ? Color.accentColor.opacity(0.1) : Color(.systemGray6))
)
}
}
}

View File

@ -0,0 +1,28 @@
//
// FilterTagChip.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct FilterTagChip: View {
let text: String
let isSelected: Bool
let onTap: () -> Void
var body: some View {
Button(action: onTap) {
Text(text)
.font(.system(size: 12, weight: .medium))
.foregroundColor(isSelected ? .white : .primary)
.padding(.horizontal, 12)
.padding(.vertical, 6)
.background(
RoundedRectangle(cornerRadius: 16)
.fill(isSelected ? Color.primaryPurple : Color(.systemGray6))
)
}
}
}

View File

@ -0,0 +1,33 @@
//
// QuickFilterTags.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct QuickFilterTags: View {
let tags: [String]
@Binding var selectedTags: Set<String>
var body: some View {
ScrollView(.horizontal, showsIndicators: false) {
HStack(spacing: 8) {
ForEach(tags, id: \.self) { tag in
FilterTagChip(
text: TagTranslationManager.shared.getLocalizedTag(tag),
isSelected: selectedTags.contains(tag)
) {
if selectedTags.contains(tag) {
selectedTags.remove(tag)
} else {
selectedTags.insert(tag)
}
}
}
}
.padding(.horizontal, 16)
}
}
}

View File

@ -0,0 +1,47 @@
//
// SourceSelector.swift
// MNNLLMiOS
//
// Created on 2025/7/4.
//
import SwiftUI
struct SourceSelector: View {
@Binding var selectedSource: ModelSource
@Binding var showSourceMenu: Bool
let onSourceChange: (ModelSource) -> Void
var body: some View {
Menu {
ForEach(ModelSource.allCases) { source in
Button(action: {
onSourceChange(source)
}) {
HStack {
Text(source.rawValue)
if source == selectedSource {
Image(systemName: "checkmark")
}
}
}
}
} label: {
HStack(spacing: 4) {
Text("modelSource.title")
.font(.system(size: 12, weight: .medium))
.foregroundColor(.primary)
Text(selectedSource.rawValue)
.font(.system(size: 12, weight: .regular))
.foregroundColor(.primary)
Image(systemName: "chevron.down")
.font(.system(size: 10))
.foregroundColor(.primary)
}
.padding(.horizontal, 6)
.padding(.vertical, 6)
}
}
}

View File

@ -0,0 +1,56 @@
//
// ToolbarView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import SwiftUI
struct ToolbarView: View {
@ObservedObject var viewModel: ModelListViewModel
@Binding var selectedSource: ModelSource
@Binding var showSourceMenu: Bool
@Binding var selectedTags: Set<String>
@Binding var selectedCategories: Set<String>
@Binding var selectedVendors: Set<String>
let quickFilterTags: [String]
@Binding var showFilterMenu: Bool
let onSourceChange: (ModelSource) -> Void
var body: some View {
VStack(spacing: 12) {
HStack {
SourceSelector(
selectedSource: $selectedSource,
showSourceMenu: $showSourceMenu,
onSourceChange: onSourceChange
)
// Quick filter tags rendered inline beside the source selector
QuickFilterTags(
tags: quickFilterTags,
selectedTags: $selectedTags
)
Spacer()
FilterButton(
showFilterMenu: $showFilterMenu,
selectedTags: $selectedTags,
selectedCategories: $selectedCategories,
selectedVendors: $selectedVendors
)
}
.padding(.horizontal, 16)
}
.padding(.vertical, 8)
.background(Color(.systemBackground))
.overlay(
Rectangle()
.frame(height: 0.5)
.foregroundColor(Color(.separator)),
alignment: .bottom
)
}
}

View File

@ -0,0 +1,146 @@
//
// ModelListView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import SwiftUI
struct ModelListView: View {
@ObservedObject var viewModel: ModelListViewModel
@State private var searchText = ""
@State private var selectedSource = ModelSourceManager.shared.selectedSource
@State private var showSourceMenu = false
@State private var selectedTags: Set<String> = []
@State private var selectedCategories: Set<String> = []
@State private var selectedVendors: Set<String> = []
@State private var showFilterMenu = false
var body: some View {
ScrollView {
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
Section {
modelListSection
} header: {
toolbarSection
}
}
}
.searchable(text: $searchText, prompt: "Search models...")
.onChange(of: searchText) { _, newValue in
viewModel.searchText = newValue
}
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK") { }
} message: {
Text(viewModel.errorMessage)
}
}
// Extract the model list section as an independent view
@ViewBuilder
private var modelListSection: some View {
LazyVStack(spacing: 8) {
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
modelRowView(model: model, index: index)
if index < filteredModels.count - 1 {
Divider()
.padding(.leading, 60)
}
}
}
.padding(.vertical, 8)
}
// Extract the toolbar section as an independent view
@ViewBuilder
private var toolbarSection: some View {
ToolbarView(
viewModel: viewModel, selectedSource: $selectedSource,
showSourceMenu: $showSourceMenu,
selectedTags: $selectedTags,
selectedCategories: $selectedCategories,
selectedVendors: $selectedVendors,
quickFilterTags: viewModel.quickFilterTags,
showFilterMenu: $showFilterMenu,
onSourceChange: handleSourceChange
)
}
@ViewBuilder
private func modelRowView(model: ModelInfo, index: Int) -> some View {
ModelRowView(
model: model,
viewModel: viewModel,
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
isDownloading: viewModel.currentlyDownloading == model.id,
isOtherDownloading: isOtherDownloadingCheck(model: model)
) {
Task {
await viewModel.downloadModel(model)
}
}
.padding(.horizontal, 16)
}
// Extract complex boolean logic as an independent method
private func isOtherDownloadingCheck(model: ModelInfo) -> Bool {
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
}
// Extract source change handling logic as an independent method
private func handleSourceChange(_ source: ModelSource) {
ModelSourceManager.shared.updateSelectedSource(source)
selectedSource = source
Task {
await viewModel.fetchModels()
}
}
// Filter models by the selected tags, categories, and vendors
private var filteredModels: [ModelInfo] {
let baseFiltered = viewModel.filteredModels
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
return baseFiltered
}
return baseFiltered.filter { model in
let tagMatch = checkTagMatch(model: model)
let categoryMatch = checkCategoryMatch(model: model)
let vendorMatch = checkVendorMatch(model: model)
return tagMatch && categoryMatch && vendorMatch
}
}
// Extract tag matching logic as an independent method
private func checkTagMatch(model: ModelInfo) -> Bool {
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
model.localizedTags.contains { tag in
tag.localizedCaseInsensitiveContains(selectedTag)
}
}
}
// Extract category matching logic as an independent method
private func checkCategoryMatch(model: ModelInfo) -> Bool {
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
model.categories?.contains { category in
category.localizedCaseInsensitiveContains(selectedCategory)
} ?? false
}
}
// Extract vendor matching logic as an independent method
private func checkVendorMatch(model: ModelInfo) -> Bool {
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
}
}
}
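The combined filter above uses AND semantics across filter types, and allSatisfy requires every selected tag (and category) to match; a model's single vendor, by contrast, only needs to match one of the selected vendors. A standalone sketch of those semantics, with illustrative values:
import Foundation
// Standalone sketch of the filter semantics; all values are illustrative.
let modelTags = ["Qwen", "4-bit", "GPU"]
let modelVendor = "Alibaba"
let selectedTags: Set<String> = ["qwen", "gpu"]        // every selected tag must match
let selectedVendors: Set<String> = ["alibaba", "meta"] // any selected vendor may match
let tagMatch = selectedTags.isEmpty || selectedTags.allSatisfy { selected in
    modelTags.contains { $0.localizedCaseInsensitiveContains(selected) }
}
let vendorMatch = selectedVendors.isEmpty || selectedVendors.contains { selected in
    modelVendor.localizedCaseInsensitiveContains(selected)
}
print(tagMatch && vendorMatch) // true: both selected tags match, and one vendor matches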

View File

@ -0,0 +1,43 @@
//
// ActionButtonsView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
// MARK: - Action Buttons View
struct ActionButtonsView: View {
let model: ModelInfo
@ObservedObject var viewModel: ModelListViewModel
let downloadProgress: Double
let isDownloading: Bool
let isOtherDownloading: Bool
let formattedSize: String
let onDownload: () -> Void
@Binding var showDeleteAlert: Bool
var body: some View {
VStack(alignment: .center, spacing: 4) {
if model.isDownloaded {
// Downloaded: offer deletion
DownloadedButtonView(showDeleteAlert: $showDeleteAlert)
} else if isDownloading {
// Downloading: show progress with tap-to-cancel
DownloadingButtonView(
viewModel: viewModel,
downloadProgress: downloadProgress
)
} else {
// Not downloaded: show the download entry point
PendingDownloadButtonView(
isOtherDownloading: isOtherDownloading,
formattedSize: formattedSize,
onDownload: onDownload
)
}
}
.frame(width: 60)
}
}

View File

@ -0,0 +1,30 @@
//
// DownloadedButtonView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
// MARK: - Downloaded Button View
struct DownloadedButtonView: View {
@Binding var showDeleteAlert: Bool
var body: some View {
Button(action: { showDeleteAlert = true }) {
VStack(spacing: 2) {
Image(systemName: "trash")
.font(.system(size: 16))
.foregroundColor(.primary.opacity(0.8))
Text(LocalizedStringKey("button.downloaded"))
.font(.caption2)
.foregroundColor(.secondary)
.lineLimit(1)
.minimumScaleFactor(0.8)
.allowsTightening(true)
}
}
}
}

View File

@ -0,0 +1,32 @@
//
// DownloadingButtonView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
// MARK: - Downloading Button View
struct DownloadingButtonView: View {
@ObservedObject var viewModel: ModelListViewModel
let downloadProgress: Double
var body: some View {
Button(action: {
Task {
await viewModel.cancelDownload()
}
}) {
VStack(spacing: 2) {
ProgressView(value: downloadProgress)
.progressViewStyle(CircularProgressViewStyle(tint: .accentColor))
.frame(width: 24, height: 24)
Text(String(format: "%.2f%%", downloadProgress * 100))
.font(.caption2)
.foregroundColor(.secondary)
}
}
}
}

View File

@ -0,0 +1,23 @@
//
// PendingDownloadButtonView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
struct PendingDownloadButtonView: View {
let isOtherDownloading: Bool
let formattedSize: String
let onDownload: () -> Void
var body: some View {
Button(action: onDownload) {
Image(systemName: "arrow.down.circle.fill")
.font(.title2)
.foregroundColor(isOtherDownloading ? .secondary : .primaryPurple)
}
.disabled(isOtherDownloading)
}
}

View File

@ -0,0 +1,25 @@
//
// TagChip.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
// MARK: - Tag Chip
struct TagChip: View {
let text: String
var body: some View {
Text(TagTranslationManager.shared.getLocalizedTag(text))
.font(.caption)
.foregroundColor(.secondary)
.padding(.horizontal, 8)
.padding(.vertical, 3)
.background(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.secondary.opacity(0.3), lineWidth: 0.5)
)
}
}

View File

@ -0,0 +1,25 @@
//
// TagsView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
// MARK: - Tags View
struct TagsView: View {
let tags: [String]
var body: some View {
ScrollView(.horizontal, showsIndicators: false) {
HStack(spacing: 6) {
ForEach(tags, id: \.self) { tag in
TagChip(text: tag)
}
}
.padding(.horizontal, 1)
}
.frame(height: 25)
}
}

View File

@ -0,0 +1,107 @@
//
// ModelRowView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import SwiftUI
struct ModelRowView: View {
let model: ModelInfo
@ObservedObject var viewModel: ModelListViewModel
let downloadProgress: Double
let isDownloading: Bool
let isOtherDownloading: Bool
let onDownload: () -> Void
@State private var showDeleteAlert = false
private var localizedTags: [String] {
model.localizedTags
}
private var formattedSize: String {
model.formattedSize
}
var body: some View {
HStack(alignment: .center, spacing: 0) {
ModelIconView(modelId: model.id)
.frame(width: 40, height: 40)
VStack(alignment: .leading, spacing: 6) {
Text(model.modelName)
.font(.headline)
.fontWeight(.semibold)
.lineLimit(1)
if !localizedTags.isEmpty {
TagsView(tags: localizedTags)
}
HStack(alignment: .center, spacing: 2) {
Image(systemName: "folder")
.font(.caption)
.fontWeight(.medium)
.foregroundColor(.gray)
.frame(width: 20, height: 20)
Text(formattedSize)
.font(.caption)
.fontWeight(.medium)
.foregroundColor(.gray)
}
}
.padding(.leading, 8)
Spacer()
VStack {
Spacer()
ActionButtonsView(
model: model,
viewModel: viewModel,
downloadProgress: downloadProgress,
isDownloading: isDownloading,
isOtherDownloading: isOtherDownloading,
formattedSize: formattedSize,
onDownload: onDownload,
showDeleteAlert: $showDeleteAlert
)
Spacer()
}
}
.padding(.vertical, 8)
.contentShape(Rectangle())
.onTapGesture {
handleRowTap()
}
.alert(LocalizedStringKey("alert.deleteModel.title"), isPresented: $showDeleteAlert) {
Button("Delete", role: .destructive) {
Task {
await viewModel.deleteModel(model)
}
}
Button("Cancel", role: .cancel) { }
} message: {
Text(LocalizedStringKey("alert.deleteModel.message"))
}
}
private func handleRowTap() {
if model.isDownloaded {
return
} else if isDownloading {
Task {
await viewModel.cancelDownload()
}
} else if !isOtherDownloading {
onDownload()
}
}
}

View File

@ -14,8 +14,10 @@ struct SearchBar: View {
HStack {
Image(systemName: "magnifyingglass")
.foregroundColor(.gray)
.padding(.horizontal, 10)
TextField("Search models...", text: $text)
.font(.system(size: 12, weight: .regular))
.textFieldStyle(RoundedBorderTextFieldStyle())
.autocapitalization(.none)
.disableAutocorrection(true)

View File

@ -0,0 +1,40 @@
//
// SwipeActionsView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
struct SwipeActionsView: View {
let model: ModelInfo
@ObservedObject var viewModel: ModelListViewModel
var body: some View {
if viewModel.pinnedModelIds.contains(model.id) {
Button {
viewModel.unpinModel(model)
} label: {
Label(LocalizedStringKey("button.unpin"), systemImage: "pin.slash")
}.tint(.gray)
} else {
Button {
viewModel.pinModel(model)
} label: {
Label(LocalizedStringKey("button.pin"), systemImage: "pin")
}.tint(.primaryBlue)
}
if model.isDownloaded {
Button(role: .destructive) {
Task {
await viewModel.deleteModel(model)
}
} label: {
Label("Delete", systemImage: "trash")
}
.tint(.primaryRed)
}
}
}
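SwipeActionsView is designed to be embedded in a List row's .swipeActions modifier. A hypothetical call site (the real wiring lives in the list view; Text stands in for the actual row view):
// Hypothetical embedding; the actual row content is simplified here.
List(viewModel.filteredModels, id: \.id) { model in
    Text(model.modelName)
        .swipeActions(edge: .trailing, allowsFullSwipe: false) {
            SwipeActionsView(model: model, viewModel: viewModel)
        }
}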

View File

@ -1,5 +1,5 @@
//
// ModelDownloadStorage.swift
// ModelSource.swift
// MNNLLMiOS
//
// Created by () on 2025/2/20.
@ -7,11 +7,13 @@
import Foundation
public enum ModelSource: String, CaseIterable {
public enum ModelSource: String, CaseIterable, Identifiable {
case modelScope = "ModelScope"
case huggingFace = "Hugging Face"
case huggingFace = "HuggingFace"
case modeler = "Modeler"
public var id: Self { self }
var description: String {
switch self {
case .modelScope:
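The Identifiable conformance added above, with id returning self, is what lets SourceSelector iterate ModelSource.allCases in a ForEach without an explicit id: argument. A minimal sketch of the pattern with a throwaway enum:
import SwiftUI
// Sketch: a CaseIterable + Identifiable enum drives ForEach directly.
enum Fruit: String, CaseIterable, Identifiable {
    case apple = "Apple", pear = "Pear"
    var id: Self { self }
}
struct FruitList: View {
    var body: some View {
        ForEach(Fruit.allCases) { fruit in  // no id: parameter needed
            Text(fruit.rawValue)
        }
    }
}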

View File

@ -1,5 +1,5 @@
//
// ModelDownloadStorage.swift
// ModelSourceManager.swift
// MNNLLMiOS
//
// Created by () on 2025/2/20.

View File

@ -0,0 +1,101 @@
//
// ModelStorageManager.swift
// MNNLLMiOS
//
// Created by () on 2025/1/10.
//
import Foundation
class ModelStorageManager {
static let shared = ModelStorageManager()
private let userDefaults = UserDefaults.standard
private let downloadedModelsKey = "com.mnnllm.downloadedModels"
private let lastUsedModelKey = "com.mnnllm.lastUsedModels"
private let cachedSizesKey = "com.mnnllm.cachedSizes"
private init() {}
var lastUsedModels: [String: Date] {
get {
userDefaults.dictionary(forKey: lastUsedModelKey) as? [String: Date] ?? [:]
}
set {
userDefaults.set(newValue, forKey: lastUsedModelKey)
}
}
func updateLastUsed(for modelName: String) {
var models = lastUsedModels
models[modelName] = Date()
lastUsedModels = models
}
func getLastUsed(for modelName: String) -> Date? {
return lastUsedModels[modelName]
}
var downloadedModels: [String] {
get {
userDefaults.array(forKey: downloadedModelsKey) as? [String] ?? []
}
set {
userDefaults.set(newValue, forKey: downloadedModelsKey)
}
}
func clearDownloadStatus(for modelName: String) {
var models = downloadedModels
models.removeAll { $0 == modelName }
downloadedModels = models
}
func isModelDownloaded(_ modelName: String) -> Bool {
downloadedModels.contains(modelName)
}
func markModelAsDownloaded(_ modelName: String) {
var models = downloadedModels
if !models.contains(modelName) {
models.append(modelName)
downloadedModels = models
}
}
func markModelAsNotDownloaded(_ modelName: String) {
var models = downloadedModels
models.removeAll { $0 == modelName }
downloadedModels = models
// Also clear the cached size when the model is marked as not downloaded
clearCachedSize(for: modelName)
}
// MARK: - Cached Size Management
var cachedSizes: [String: Int64] {
get {
userDefaults.dictionary(forKey: cachedSizesKey) as? [String: Int64] ?? [:]
}
set {
userDefaults.set(newValue, forKey: cachedSizesKey)
}
}
func getCachedSize(for modelName: String) -> Int64? {
return cachedSizes[modelName]
}
func setCachedSize(_ size: Int64, for modelName: String) {
var sizes = cachedSizes
sizes[modelName] = size
cachedSizes = sizes
}
func clearCachedSize(for modelName: String) {
var sizes = cachedSizes
sizes.removeValue(forKey: modelName)
cachedSizes = sizes
}
}
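A short usage sketch of the storage API above; the model name is a placeholder:
import Foundation
// Usage sketch; "Qwen2.5-0.5B" is a placeholder model name.
let storage = ModelStorageManager.shared
storage.markModelAsDownloaded("Qwen2.5-0.5B")
storage.setCachedSize(512_000_000, for: "Qwen2.5-0.5B")
storage.updateLastUsed(for: "Qwen2.5-0.5B")
if storage.isModelDownloaded("Qwen2.5-0.5B") {
    print("cached size:", storage.getCachedSize(for: "Qwen2.5-0.5B") ?? 0)
}
// markModelAsNotDownloaded also removes the cached size entry.
storage.markModelAsNotDownloaded("Qwen2.5-0.5B")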

View File

@ -0,0 +1,29 @@
//
// TagTranslationManager.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import Foundation
class TagTranslationManager {
static let shared = TagTranslationManager()
private var tagTranslations: [String: String] = [:]
private init() {}
func loadTagTranslations(_ translations: [String: String]) {
tagTranslations = translations
}
func getLocalizedTag(_ tag: String) -> String {
let currentLanguage = LanguageManager.shared.currentLanguage
let isChineseLanguage = currentLanguage == "zh-Hans"
if isChineseLanguage, let translation = tagTranslations[tag] {
return translation
}
return tag
}
}
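The translation table is injected at runtime through loadTagTranslations rather than bundled with the app; a usage sketch with placeholder entries:
// Placeholder entries; real translations come from the fetched model metadata.
TagTranslationManager.shared.loadTagTranslations([
    "Text Generation": "文本生成",
    "Multimodal": "多模态"
])
// Returns the Chinese label only when the current language is zh-Hans;
// otherwise the original tag is returned unchanged.
let label = TagTranslationManager.shared.getLocalizedTag("Multimodal")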

View File

@ -1,43 +0,0 @@
//
// ModelClient.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import Hub
import Foundation
struct ModelInfo: Codable {
let modelId: String
let createdAt: String
let downloads: Int
let tags: [String]
var name: String {
modelId.removingTaobaoPrefix()
}
var isDownloaded: Bool = false
var localPath: String {
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
}
private enum CodingKeys: String, CodingKey {
case modelId
case tags
case downloads
case createdAt
}
}
struct RepoInfo: Codable {
let modelId: String
let sha: String
let siblings: [Sibling]
struct Sibling: Codable {
let rfilename: String
}
}

View File

@ -1,155 +0,0 @@
//
// ModelListViewModel.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import Foundation
@MainActor
class ModelListViewModel: ObservableObject {
@Published private(set) var models: [ModelInfo] = []
@Published private(set) var downloadProgress: [String: Double] = [:]
@Published private(set) var currentlyDownloading: String?
@Published var showError = false
@Published var errorMessage = ""
@Published var searchText = ""
@Published var selectedModel: ModelInfo?
private let modelClient = ModelClient()
var filteredModels: [ModelInfo] {
let filteredModels = searchText.isEmpty ? models : models.filter { model in
model.modelId.localizedCaseInsensitiveContains(searchText) ||
model.tags.contains { $0.localizedCaseInsensitiveContains(searchText) }
}
let downloadedModels = filteredModels.filter { $0.isDownloaded }
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
return downloadedModels + notDownloadedModels
}
init() {
Task {
await fetchModels()
}
}
func fetchModels() async {
do {
var fetchedModels = try await modelClient.getModelList()
let hasDiffusionModels = fetchedModels.contains {
$0.name.lowercased().contains("diffusion")
}
if hasDiffusionModels {
fetchedModels = fetchedModels.filter { model in
let name = model.name.lowercased()
let tags = model.tags.map { $0.lowercased() }
// only show gpu diffusion
if name.contains("diffusion") {
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
}
return true
}
}
for i in 0..<fetchedModels.count {
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(fetchedModels[i].modelId)
}
models = fetchedModels
} catch {
showError = true
errorMessage = "Error: \(error.localizedDescription)"
}
}
func selectModel(_ model: ModelInfo) {
if model.isDownloaded {
selectedModel = model
} else {
Task {
await downloadModel(model)
}
}
}
func downloadModel(_ model: ModelInfo) async {
guard currentlyDownloading == nil else { return }
currentlyDownloading = model.modelId
downloadProgress[model.modelId] = 0
Task(priority: .background) {
do {
try await modelClient.downloadModel(model: model) { progress in
Task { @MainActor in
DispatchQueue.main.async {
self.downloadProgress[model.modelId] = progress
}
}
}
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
models[index].isDownloaded = true
DispatchQueue.main.async {
ModelStorageManager.shared.markModelAsDownloaded(model.modelId)
}
}
} catch {
showError = true
errorMessage = "Failed to download model: \(error.localizedDescription)"
}
currentlyDownloading = nil
}
}
func deleteModel(_ model: ModelInfo) async {
do {
let fileManager = FileManager.default
let modelPath = URL.init(filePath: model.localPath)
// Clear per-file download records before deleting the directory
if let files = try? fileManager.contentsOfDirectory(
at: modelPath,
includingPropertiesForKeys: nil,
options: [.skipsHiddenFiles]
) {
// Reset the stored status for each file
let storage = ModelDownloadStorage()
for file in files {
storage.clearFileStatus(at: file.path)
}
}
// Remove the model directory itself
if fileManager.fileExists(atPath: modelPath.path) {
try fileManager.removeItem(at: modelPath)
}
await MainActor.run {
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
models[index].isDownloaded = false
ModelStorageManager.shared.clearDownloadStatus(for: model.modelId)
}
if selectedModel?.modelId == model.modelId {
selectedModel = nil
}
}
} catch {
print("Error deleting model: \(error)")
await MainActor.run {
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
self.showError = true
}
}
}
}

View File

@ -1,44 +0,0 @@
//
// ModelStorageManager.swift
// MNNLLMiOS
//
// Created by () on 2025/1/10.
//
import Foundation
class ModelStorageManager {
static let shared = ModelStorageManager()
private let userDefaults = UserDefaults.standard
private let downloadedModelsKey = "com.mnnllm.downloadedModels"
private init() {}
var downloadedModels: [String] {
get {
userDefaults.array(forKey: downloadedModelsKey) as? [String] ?? []
}
set {
userDefaults.set(newValue, forKey: downloadedModelsKey)
}
}
func clearDownloadStatus(for modelId: String) {
var models = downloadedModels
models.removeAll { $0 == modelId }
downloadedModels = models
}
func isModelDownloaded(_ modelId: String) -> Bool {
downloadedModels.contains(modelId)
}
func markModelAsDownloaded(_ modelId: String) {
var models = downloadedModels
if !models.contains(modelId) {
models.append(modelId)
downloadedModels = models
}
}
}

View File

@ -1,110 +0,0 @@
//
// ModelClient.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import Hub
import Foundation
class ModelClient {
private let baseMirrorURL = "https://hf-mirror.com"
private let baseURL = "https://huggingface.co"
private let maxRetries = 5
private lazy var baseURLString: String = {
switch ModelSourceManager.shared.selectedSource {
case .huggingFace:
return baseURL
default:
return baseMirrorURL
}
}()
init() {}
func getModelList() async throws -> [ModelInfo] {
let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
return try await performRequest(url: url, retries: maxRetries)
}
func getRepoInfo(repoName: String, revision: String) async throws -> RepoInfo {
let url = URL(string: "\(baseURLString)/api/models/\(repoName)")!
return try await performRequest(url: url, retries: maxRetries)
}
@MainActor
func downloadModel(model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
switch ModelSourceManager.shared.selectedSource {
case .modelScope, .modeler:
try await downloadFromModelScope(model, progress: progress)
case .huggingFace:
try await downloadFromHuggingFace(model, progress: progress)
}
}
private func downloadFromModelScope(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
let ModelScopeId = model.modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30
config.timeoutIntervalForResource = 300
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.name) { fileProgress in
progress(fileProgress)
}
}
private func downloadFromHuggingFace(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws {
let repo = Hub.Repo(id: model.modelId)
let modelFiles = ["*.*"]
let mirrorHubApi = HubApi(endpoint: baseURL)
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
progress(fileProgress.fractionCompleted)
}
}
private func performRequest<T: Decodable>(url: URL, retries: Int = 3) async throws -> T {
var lastError: Error?
for attempt in 1...retries {
do {
var request = URLRequest(url: url)
request.setValue("application/json", forHTTPHeaderField: "Accept")
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw NetworkError.invalidResponse
}
if httpResponse.statusCode == 200 {
return try JSONDecoder().decode(T.self, from: data)
}
throw NetworkError.invalidResponse
} catch {
lastError = error
if attempt < retries {
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(attempt)) * 1_000_000_000))
continue
}
}
}
throw lastError ?? NetworkError.unknown
}
}
enum NetworkError: Error {
case invalidResponse
case invalidData
case downloadFailed
case unknown
}

View File

@ -1,202 +0,0 @@
//
// ModelListView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
struct ModelListView: View {
@State private var scrollOffset: CGFloat = 0
@State private var showHelp = false
@State private var showUserGuide = false
@State private var showHistory = false
@State private var selectedHistory: ChatHistory?
@State private var histories: [ChatHistory] = []
@State private var showSettings = false
@State private var showWebView = false
@State private var webViewURL: URL?
@StateObject private var viewModel = ModelListViewModel()
var body: some View {
ZStack {
NavigationView {
List {
SearchBar(text: $viewModel.searchText)
.listRowInsets(EdgeInsets())
.listRowSeparator(.hidden)
.padding(.horizontal)
ForEach(viewModel.filteredModels, id: \.modelId) { model in
ModelRowView(model: model,
downloadProgress: viewModel.downloadProgress[model.modelId] ?? 0,
isDownloading: viewModel.currentlyDownloading == model.modelId,
isOtherDownloading: viewModel.currentlyDownloading != nil) {
if model.isDownloaded {
viewModel.selectModel(model)
} else {
Task {
await viewModel.downloadModel(model)
}
}
}
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
if model.isDownloaded {
Button(role: .destructive) {
Task {
await viewModel.deleteModel(model)
}
} label: {
Label("Delete", systemImage: "trash")
}
}
}
}
}
.listStyle(.plain)
.navigationTitle("Models")
.navigationBarTitleDisplayMode(.large)
.navigationBarItems(
leading: Button(action: {
showHistory.toggle()
updateHistory()
}) {
Image(systemName: "clock.arrow.circlepath")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 22, height: 22)
},
trailing: settingsButton
)
.sheet(isPresented: $showHelp) {
HelpView()
}
.sheet(isPresented: $showWebView) {
if let url = webViewURL {
WebView(url: url)
}
}
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK", role: .cancel) {}
} message: {
Text(viewModel.errorMessage)
}
.background(
NavigationLink(
destination: {
if let selectedModel = viewModel.selectedModel {
return AnyView(LLMChatView(modelInfo: selectedModel))
} else if let selectedHistory = selectedHistory {
return AnyView(LLMChatView(modelInfo: ModelInfo(
modelId: selectedHistory.modelId,
createdAt: selectedHistory.createdAt.formatAgo(),
downloads: 0,
tags: [],
isDownloaded: true
), history: selectedHistory))
}
return AnyView(EmptyView())
}(),
isActive: Binding(
get: { viewModel.selectedModel != nil || selectedHistory != nil },
set: { if !$0 { viewModel.selectedModel = nil; selectedHistory = nil } }
)
) {
EmptyView()
}
)
.onAppear {
checkFirstLaunch()
}
.alert(isPresented: $showUserGuide) {
Alert(
title: Text("User Guide"),
message: Text("""
This is a local large model application that requires certain performance from your device.
It is recommended to choose different model sizes based on your device's memory.
The model recommendations for iPhone are as follows:
- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).
- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).
- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).
Choosing a model that is too large may cause insufficient memory and crashes.
"""),
dismissButton: .default(Text("OK"))
)
}
}
.disabled(showHistory)
if showHistory {
Color.black.opacity(0.5)
.edgesIgnoringSafeArea(.all)
.onTapGesture {
withAnimation {
showHistory = false
}
}
}
SideMenuView(isOpen: $showHistory, selectedHistory: $selectedHistory, histories: $histories)
.edgesIgnoringSafeArea(.all)
}
.onAppear {
updateHistory()
}
.actionSheet(isPresented: $showSettings) {
ActionSheet(title: Text("Settings"), buttons: [
.default(Text("Report an Issue")) {
webViewURL = URL(string: "https://github.com/alibaba/MNN/issues")
showWebView = true
},
.default(Text("Go to MNN Homepage")) {
webViewURL = URL(string: "https://github.com/alibaba/MNN")
showWebView = true
},
.default(Text(ModelSource.modelScope.description)) {
ModelSourceManager.shared.updateSelectedSource(.modelScope)
},
.default(Text(ModelSource.modeler.description)) {
ModelSourceManager.shared.updateSelectedSource(.modeler)
},
.default(Text(ModelSource.huggingFace.description)) {
ModelSourceManager.shared.updateSelectedSource(.huggingFace)
},
.cancel()
])
}
}
private func updateHistory() {
histories = ChatHistoryManager.shared.getAllHistory()
}
private func checkFirstLaunch() {
let hasLaunchedBefore = UserDefaults.standard.bool(forKey: "hasLaunchedBefore")
if !hasLaunchedBefore {
// Show the user guide alert
showUserGuide = true
// Set the flag to true so it doesn't show again
UserDefaults.standard.set(true, forKey: "hasLaunchedBefore")
}
}
private var settingsButton: some View {
Button(action: {
showSettings.toggle()
}) {
Image(systemName: "gear")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 22, height: 22)
}
}
}

View File

@ -1,60 +0,0 @@
//
// ModelRowView.swift
// MNNLLMiOS
//
// Created by () on 2025/1/3.
//
import SwiftUI
struct ModelRowView: View {
let model: ModelInfo
let downloadProgress: Double
let isDownloading: Bool
let isOtherDownloading: Bool
let onDownload: () -> Void
var body: some View {
HStack(alignment: .top) {
ModelIconView(modelId: model.modelId)
.frame(width: 50, height: 50)
VStack(alignment: .leading, spacing: 8) {
Text(model.name)
.font(.headline)
.lineLimit(1)
if !model.tags.isEmpty {
ScrollView(.horizontal, showsIndicators: false) {
HStack {
ForEach(model.tags, id: \.self) { tag in
Text(tag)
.font(.caption)
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(Color.blue.opacity(0.1))
.cornerRadius(8)
}
}
}
}
if isDownloading {
ProgressView(value: downloadProgress) {
Text(String(format: "%.2f%%", downloadProgress * 100))
.font(.system(size: 14, weight: .regular, design: .default))
}
} else {
Button(action: onDownload) {
Label(model.isDownloaded ? "Chat" : "Download",
systemImage: model.isDownloaded ? "message" : "arrow.down.circle")
.font(.system(size: 14, weight: .medium, design: .default))
}
.disabled(isOtherDownloading)
}
}
}
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,96 @@
//
// SettingsView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/1.
//
import SwiftUI
struct SettingsView: View {
private var sourceManager = ModelSourceManager.shared
@State private var selectedLanguage = ""
@State private var selectedSource = ModelSourceManager.shared.selectedSource
@State private var showLanguageAlert = false
private let languageOptions = LanguageManager.shared.languageOptions
var body: some View {
List {
Section(header: Text("settings.section.application")) {
Picker("settings.picker.downloadSource", selection: $selectedSource) {
ForEach(ModelSource.allCases, id: \.self) { source in
Text(source.rawValue).tag(source)
}
}
.onChange(of: selectedSource) { _, newValue in
sourceManager.updateSelectedSource(newValue)
}
Picker("settings.picker.language", selection: $selectedLanguage) {
ForEach(languageOptions.keys.sorted(), id: \.self) { key in
Text(languageOptions[key] ?? "").tag(key)
}
}
.onChange(of: selectedLanguage) { _, newValue in
if newValue != LanguageManager.shared.currentLanguage {
showLanguageAlert = true
}
}
}
Section(header: Text("settings.section.about")) {
Button(action: {
if let url = URL(string: "https://github.com/alibaba/MNN") {
UIApplication.shared.open(url)
}
}) {
HStack {
Text("settings.button.aboutMNN")
Spacer()
Image(systemName: "chevron.right")
.foregroundColor(.gray)
.font(.system(size: 14))
}
.foregroundColor(.primary)
}
Button(action: {
if let url = URL(string: "https://github.com/alibaba/MNN") {
UIApplication.shared.open(url)
}
}) {
HStack {
Text("settings.button.reportIssue")
Spacer()
Image(systemName: "chevron.right")
.foregroundColor(.gray)
.font(.system(size: 14))
}
.foregroundColor(.primary)
}
}
}
.listStyle(InsetGroupedListStyle())
.navigationTitle("settings.navigation.title")
.navigationBarTitleDisplayMode(.inline)
.alert("settings.alert.switchLanguage.title", isPresented: $showLanguageAlert) {
Button("settings.alert.switchLanguage.confirm") {
LanguageManager.shared.applyLanguage(selectedLanguage)
// Exit so the app relaunches in the selected language
exit(0)
}
Button("settings.alert.switchLanguage.cancel", role: .cancel) {
// Revert the picker to the active language
selectedLanguage = LanguageManager.shared.currentLanguage
}
} message: {
Text("settings.alert.switchLanguage.message")
}
.onAppear {
selectedLanguage = LanguageManager.shared.currentLanguage
}
}
}

View File

@ -12,4 +12,19 @@ extension Color {
static var customBlue = Color(hex: "4859FD")
static var customPickerBg = Color(hex: "2F2F2F")
static var customLightPink = Color(hex: "E3E3E3")
static var primaryPurple = Color(hex: "4252B6")
static var primaryBlue = Color(hex: "2E97F2")
static var primaryRed = Color(hex: "D16D6A")
// Enhanced colors for benchmark UI
static var benchmarkGradientStart = Color(hex: "667eea")
static var benchmarkGradientEnd = Color(hex: "764ba2")
static var benchmarkCardBg = Color(hex: "FFFFFF")
static var benchmarkAccent = Color(hex: "6366f1")
static var benchmarkSuccess = Color(hex: "10b981")
static var benchmarkWarning = Color(hex: "f59e0b")
static var benchmarkError = Color(hex: "ef4444")
static var benchmarkSecondary = Color(hex: "6b7280")
static var benchmarkLight = Color(hex: "f8fafc")
}
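These constants rely on a Color(hex:) initializer defined elsewhere in the project; a typical shape for such an initializer (an assumption about its behaviour, not the project's actual code) is:
import SwiftUI
// Assumed shape of the Color(hex:) initializer used above;
// the project's real implementation may differ.
extension Color {
    init(hex: String) {
        let cleaned = hex.trimmingCharacters(in: CharacterSet.alphanumerics.inverted)
        var value: UInt64 = 0
        Scanner(string: cleaned).scanHexInt64(&value)
        self.init(
            red: Double((value >> 16) & 0xFF) / 255,
            green: Double((value >> 8) & 0xFF) / 255,
            blue: Double(value & 0xFF) / 255
        )
    }
}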

Some files were not shown because too many files have changed in this diff