mirror of https://github.com/alibaba/MNN.git
Compare commits
72 Commits
79146c6bd1
...
1e15bbf17b
Author | SHA1 | Date |
---|---|---|
|
1e15bbf17b | |
|
e814142254 | |
|
d32415dedc | |
|
98381cebf3 | |
|
5dd4338224 | |
|
e35f34b422 | |
|
78a40df272 | |
|
289504c9a8 | |
|
655e5dd884 | |
|
21dc12964a | |
|
17f2058295 | |
|
a105e0479b | |
|
f447eecbdd | |
|
81e34a22ac | |
|
e5ffe41f82 | |
|
5914a24d89 | |
|
1211e0591c | |
|
5873661270 | |
|
533ea03857 | |
|
25f24ee301 | |
|
e652145b40 | |
|
1488c1df0f | |
|
8535a6b3a9 | |
|
6cb8b29c4a | |
|
a0f867ef45 | |
|
73451790d3 | |
|
b56b8fc9ed | |
|
61e33fc603 | |
|
f3e3b19f2a | |
|
85855f7dbc | |
|
3d2091bc24 | |
|
cbe027a005 | |
|
b2050e0513 | |
|
ff5362a85a | |
|
b9d9464d53 | |
|
c7cf93ae24 | |
|
a11ec7b585 | |
|
6f8c28175e | |
|
11d1ee8283 | |
|
c4a48c01f9 | |
|
9f36b9a5c3 | |
|
5ec4211526 | |
|
0e7d0bf4e3 | |
|
cb7b28a4b1 | |
|
1125efd8c7 | |
|
5d170ed77a | |
|
4af8cfe6b7 | |
|
16032b9164 | |
|
59ebea5c03 | |
|
31a2825abb | |
|
661a62778a | |
|
bae99d5e68 | |
|
df333dd071 | |
|
814427eecf | |
|
a5579b506a | |
|
e20268dffc | |
|
8b641f61bb | |
|
b7c7f097c3 | |
|
39457666a0 | |
|
5123077a18 | |
|
25ea423f2f | |
|
cbed8ff5d2 | |
|
815f1a0548 | |
|
fc028886ed | |
|
6a86cdef5b | |
|
7719d75f03 | |
|
e426977f74 | |
|
84ef9956ea | |
|
fab1b50add | |
|
9aac4c29c3 | |
|
018ff22032 | |
|
569d23a0ad |
|
@ -35,10 +35,10 @@ jobs:
|
|||
cd project/android
|
||||
mkdir build_64
|
||||
cd build_64
|
||||
../build_64.sh
|
||||
../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true -DMNN_SUPPORT_BF16=ON
|
||||
- name: build_android_32
|
||||
run: |
|
||||
cd project/android
|
||||
mkdir build_32
|
||||
cd build_32
|
||||
../build_32.sh
|
||||
../build_32.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
|
||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
mkdir build && cd build
|
||||
cmake -DMNN_BUILD_TEST=ON ..
|
||||
cmake .. -DMNN_BUILD_TEST=ON -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true -DMNN_OPENCL=ON -DMNN_VULKAN=ON
|
||||
make -j4
|
||||
- name: test
|
||||
run: cd build && ./run_test.out
|
||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
mkdir build && cd build
|
||||
cmake -DMNN_BUILD_TEST=ON ..
|
||||
cmake .. -DMNN_BUILD_TEST=ON -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true -DMNN_OPENCL=ON -DMNN_VULKAN=ON
|
||||
make -j4
|
||||
- name: test
|
||||
run: cd build && ./run_test.out
|
||||
|
|
|
@ -82,7 +82,7 @@ jobs:
|
|||
submodules: true
|
||||
|
||||
- name: build
|
||||
run: ./package_scripts/mac/buildFrameWork.sh
|
||||
run: ./package_scripts/mac/buildFrameWork.sh -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true
|
||||
- name: package
|
||||
run: |
|
||||
rm -rf ${{ env.PACKAGENAME }}
|
||||
|
@ -129,12 +129,12 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
brew install coreutils
|
||||
./package_scripts/ios/xcodebuildiOS.sh -o ios_build
|
||||
./package_scripts/ios/buildiOS.sh -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true
|
||||
|
||||
- name: package
|
||||
run: |
|
||||
rm -f ${{ env.PACKAGENAME }}.zip
|
||||
zip -9 -y -r ${{ env.PACKAGENAME }}.zip ios_build/Release-iphoneos/MNN.framework
|
||||
zip -9 -y -r ${{ env.PACKAGENAME }}.zip MNN-iOS-CPU-GPU/Static/MNN.framework
|
||||
- name: upload-zip
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
@ -161,4 +161,4 @@ jobs:
|
|||
with:
|
||||
file: assert/*.zip
|
||||
tags: true
|
||||
draft: true
|
||||
draft: true
|
||||
|
|
|
@ -30,7 +30,7 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
cd pymnn/pip_package
|
||||
python3 build_deps.py
|
||||
python3 build_deps.py llm
|
||||
sudo python3 setup.py install --version 1.0
|
||||
- name: test
|
||||
run: |
|
||||
|
|
|
@ -33,7 +33,7 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
cd pymnn/pip_package
|
||||
python3 build_deps.py
|
||||
python3 build_deps.py llm
|
||||
python3 setup.py install --version 1.0
|
||||
- name: test
|
||||
run: |
|
||||
|
|
|
@ -34,7 +34,7 @@ jobs:
|
|||
- name: build
|
||||
run: |
|
||||
cd pymnn/pip_package
|
||||
python3 build_deps.py
|
||||
python3 build_deps.py llm
|
||||
python3 setup.py install --version 1.0
|
||||
- name: test
|
||||
run: |
|
||||
|
|
|
@ -31,5 +31,5 @@ jobs:
|
|||
- name: test
|
||||
run: |
|
||||
mkdir build && cd build
|
||||
cmake -DMNN_BUILD_TEST=ON ..
|
||||
cmake .. -DMNN_BUILD_TEST=ON -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true -DMNN_VULKAN=ON -DMNN_OPENCL=ON
|
||||
cmake --build . -j4
|
||||
|
|
|
@ -376,3 +376,5 @@ datasets/*
|
|||
|
||||
# qnn 3rdParty
|
||||
source/backend/qnn/3rdParty/include
|
||||
apps/iOS/MNNLLMChat/Chat
|
||||
apps/iOS/MNNLLMChat/swift-transformers
|
||||
|
|
|
@ -52,7 +52,7 @@
|
|||
isa = PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet;
|
||||
buildPhase = 3E8591FA2D1D45070067B46F /* Sources */;
|
||||
membershipExceptions = (
|
||||
LLMWrapper/DiffusionSession.h,
|
||||
InferenceEngine/DiffusionSession.h,
|
||||
);
|
||||
};
|
||||
/* End PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"images" : [
|
||||
{
|
||||
"filename" : "star.png",
|
||||
"idiom" : "universal",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"scale" : "3x"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
Binary file not shown.
After Width: | Height: | Size: 3.6 KiB |
|
@ -63,7 +63,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
|
|||
}
|
||||
|
||||
Task {
|
||||
var status: Message.Status = .sending
|
||||
let status: Message.Status = .sending
|
||||
|
||||
var sender: LLMChatUser
|
||||
switch userType {
|
||||
|
@ -74,13 +74,15 @@ final class LLMChatInteractor: ChatInteractorProtocol {
|
|||
case .system:
|
||||
sender = chatData.system
|
||||
}
|
||||
var message: LLMChatMessage = await draftMessage.toLLMChatMessage(
|
||||
let message: LLMChatMessage = await draftMessage.toLLMChatMessage(
|
||||
id: UUID().uuidString,
|
||||
user: sender,
|
||||
status: status)
|
||||
|
||||
DispatchQueue.main.async { [weak self] in
|
||||
|
||||
// PerformanceMonitor.shared.recordUIUpdate()
|
||||
|
||||
switch userType {
|
||||
case .user, .system:
|
||||
self?.chatState.value.append(message)
|
||||
|
@ -97,18 +99,24 @@ final class LLMChatInteractor: ChatInteractorProtocol {
|
|||
|
||||
case .assistant:
|
||||
|
||||
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
|
||||
|
||||
if let isDeepSeek = self?.modelInfo.name.lowercased().contains("deepseek"), isDeepSeek == true,
|
||||
let text = self?.processor.process(progress: message.text) {
|
||||
updateLastMsg?.text = text
|
||||
} else {
|
||||
updateLastMsg?.text += message.text
|
||||
}
|
||||
|
||||
message.text = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1].text ?? ""
|
||||
|
||||
self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] = updateLastMsg ?? message
|
||||
// PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
|
||||
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
|
||||
|
||||
if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
|
||||
let text = self?.processor.process(progress: message.text) {
|
||||
updateLastMsg?.text = text
|
||||
} else {
|
||||
if let currentText = updateLastMsg?.text {
|
||||
updateLastMsg?.text = currentText + message.text
|
||||
} else {
|
||||
updateLastMsg?.text = message.text
|
||||
}
|
||||
}
|
||||
|
||||
if let updatedMsg = updateLastMsg {
|
||||
self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] = updatedMsg
|
||||
}
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,13 +25,13 @@ final class LLMChatData {
|
|||
|
||||
self.assistant = LLMChatUser(
|
||||
uid: "2",
|
||||
name: modelInfo.name,
|
||||
name: modelInfo.modelName,
|
||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||
)
|
||||
|
||||
self.system = LLMChatUser(
|
||||
uid: "0",
|
||||
name: modelInfo.name,
|
||||
name: modelInfo.modelName,
|
||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||
)
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ actor LLMState {
|
|||
return isProcessing
|
||||
}
|
||||
|
||||
func processContent(_ content: String, llm: LLMInferenceEngineWrapper?, completion: @escaping (String) -> Void) {
|
||||
llm?.processInput(content, withOutput: completion)
|
||||
func processContent(_ content: String, llm: LLMInferenceEngineWrapper?, showPerformance: Bool, completion: @escaping (String) -> Void) {
|
||||
llm?.processInput(content, withOutput: completion, showPerformance: true)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
//
|
||||
// PerformanceMonitor.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import UIKit
|
||||
|
||||
/**
|
||||
* PerformanceMonitor - A singleton utility for monitoring and measuring UI performance
|
||||
*
|
||||
* This class provides real-time performance monitoring capabilities to help identify
|
||||
* UI update bottlenecks, frame drops, and slow operations in iOS applications.
|
||||
* It's particularly useful during development to ensure smooth user experience.
|
||||
*
|
||||
* Key Features:
|
||||
* - Real-time FPS monitoring and frame drop detection
|
||||
* - UI update lag detection with customizable thresholds
|
||||
* - Execution time measurement for specific operations
|
||||
* - Automatic performance statistics reporting
|
||||
* - Thread-safe singleton implementation
|
||||
*
|
||||
* Usage Examples:
|
||||
*
|
||||
* 1. Monitor UI Updates:
|
||||
* ```swift
|
||||
* // Call this in your UI update methods
|
||||
* PerformanceMonitor.shared.recordUIUpdate()
|
||||
* ```
|
||||
*
|
||||
* 2. Measure Operation Performance:
|
||||
* ```swift
|
||||
* let result = PerformanceMonitor.shared.measureExecutionTime(operation: "Data Processing") {
|
||||
* // Your expensive operation here
|
||||
* return processLargeDataSet()
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* 3. Integration in ViewModels:
|
||||
* ```swift
|
||||
* func updateUI() {
|
||||
* PerformanceMonitor.shared.recordUIUpdate()
|
||||
* // Your UI update code
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Performance Thresholds:
|
||||
* - Target FPS: 60 FPS
|
||||
* - Frame threshold: 25ms (1.5x normal frame time)
|
||||
* - Slow operation threshold: 16ms (1 frame time)
|
||||
*/
|
||||
class PerformanceMonitor {
|
||||
static let shared = PerformanceMonitor()
|
||||
|
||||
private var lastUpdateTime: Date = Date()
|
||||
private var updateCount: Int = 0
|
||||
private var frameDropCount: Int = 0
|
||||
private let targetFPS: Double = 60.0
|
||||
private let frameThreshold: TimeInterval = 1.0 / 60.0 * 1.5 // Allow 1.5x normal frame time
|
||||
|
||||
private init() {}
|
||||
|
||||
/**
|
||||
* Records a UI update event and monitors performance metrics
|
||||
*
|
||||
* Call this method whenever you perform UI updates to track performance.
|
||||
* It automatically detects frame drops and calculates FPS statistics.
|
||||
* Performance statistics are logged every second.
|
||||
*/
|
||||
func recordUIUpdate() {
|
||||
let currentTime = Date()
|
||||
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
||||
|
||||
updateCount += 1
|
||||
|
||||
// Detect frame drops
|
||||
if timeDiff > frameThreshold {
|
||||
frameDropCount += 1
|
||||
print("⚠️ UI Update Lag detected: \(timeDiff * 1000)ms (expected: \(frameThreshold * 1000)ms)")
|
||||
}
|
||||
|
||||
// Report statistics every second
|
||||
if timeDiff >= 1.0 {
|
||||
let actualFPS = Double(updateCount) / timeDiff
|
||||
let dropRate = Double(frameDropCount) / Double(updateCount) * 100
|
||||
|
||||
print("📊 Performance Stats - FPS: \(String(format: "%.1f", actualFPS)), Drop Rate: \(String(format: "%.1f", dropRate))%")
|
||||
|
||||
// Reset counters for next measurement cycle
|
||||
updateCount = 0
|
||||
frameDropCount = 0
|
||||
lastUpdateTime = currentTime
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Measures execution time for a specific operation
|
||||
*
|
||||
* Wraps any operation and measures its execution time. Operations taking
|
||||
* longer than 16ms (1 frame time) are logged as slow operations.
|
||||
*
|
||||
* - Parameters:
|
||||
* - operation: A descriptive name for the operation being measured
|
||||
* - block: The operation to measure
|
||||
* - Returns: The result of the operation
|
||||
* - Throws: Re-throws any error thrown by the operation
|
||||
*/
|
||||
func measureExecutionTime<T>(operation: String, block: () throws -> T) rethrows -> T {
|
||||
let startTime = CFAbsoluteTimeGetCurrent()
|
||||
let result = try block()
|
||||
let executionTime = CFAbsoluteTimeGetCurrent() - startTime
|
||||
|
||||
if executionTime > 0.016 { // Over 16ms (1 frame time)
|
||||
print("⏱️ Slow Operation: \(operation) took \(String(format: "%.3f", executionTime * 1000))ms")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
# UI Performance Optimization Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This document explains the performance optimization utilities implemented in the chat system to ensure smooth AI streaming text output and overall UI responsiveness.
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. PerformanceMonitor
|
||||
|
||||
A singleton utility for real-time performance monitoring and measurement.
|
||||
|
||||
#### Implementation Principles
|
||||
|
||||
- **Real-time FPS Monitoring**: Tracks UI update frequency and calculates actual FPS
|
||||
- **Frame Drop Detection**: Identifies when UI updates exceed the 16.67ms threshold (60 FPS)
|
||||
- **Operation Time Measurement**: Measures execution time of specific operations
|
||||
- **Automatic Statistics Reporting**: Logs performance metrics every second
|
||||
|
||||
#### Key Features
|
||||
|
||||
```swift
|
||||
class PerformanceMonitor {
|
||||
static let shared = PerformanceMonitor()
|
||||
|
||||
// Performance thresholds
|
||||
private let targetFPS: Double = 60.0
|
||||
private let frameThreshold: TimeInterval = 1.0 / 60.0 * 1.5 // 25ms threshold
|
||||
|
||||
func recordUIUpdate() { /* Track UI updates */ }
|
||||
func measureExecutionTime<T>(operation: String, block: () throws -> T) { /* Measure operations */ }
|
||||
}
|
||||
```
|
||||
|
||||
#### Usage in the Chat System
|
||||
|
||||
- Integrated into `LLMChatInteractor` to monitor message updates
|
||||
- Tracks UI update frequency during AI text streaming
|
||||
- Identifies performance bottlenecks in real-time
|
||||
|
||||
### 2. UIUpdateOptimizer
|
||||
|
||||
An actor-based utility for batching and throttling UI updates during streaming scenarios.
|
||||
|
||||
#### Implementation Principles
|
||||
|
||||
- **Batching Mechanism**: Groups multiple small updates into larger, more efficient ones
|
||||
- **Time-based Throttling**: Limits update frequency to prevent UI overload
|
||||
- **Actor-based Thread Safety**: Ensures safe concurrent access to update queue
|
||||
- **Automatic Flush Strategy**: Intelligently decides when to apply batched updates
|
||||
|
||||
#### Architecture
|
||||
|
||||
```swift
|
||||
actor UIUpdateOptimizer {
|
||||
static let shared = UIUpdateOptimizer()
|
||||
|
||||
private var pendingUpdates: [String] = []
|
||||
private let batchSize: Int = 5 // Batch threshold
|
||||
private let flushInterval: TimeInterval = 0.03 // 30ms throttling
|
||||
|
||||
func addUpdate(_ content: String, completion: @escaping (String) -> Void)
|
||||
func forceFlush(completion: @escaping (String) -> Void)
|
||||
}
|
||||
```
|
||||
|
||||
#### Optimization Strategies
|
||||
|
||||
1. **Batch Size Control**: Groups up to 5 updates before flushing
|
||||
2. **Time-based Throttling**: Flushes updates every 30ms maximum
|
||||
3. **Intelligent Scheduling**: Cancels redundant flush operations
|
||||
4. **Main Thread Delegation**: Ensures UI updates occur on the main thread
|
||||
|
||||
#### Integration Points
|
||||
|
||||
- **LLM Streaming**: Optimizes real-time text output from AI models
|
||||
- **Message Updates**: Batches frequent message content changes
|
||||
- **Force Flush**: Ensures final content is displayed when streaming ends
|
||||
|
||||
## Performance Optimization Flow
|
||||
|
||||
```
|
||||
AI Model Output → UIUpdateOptimizer → Batched Updates → UI Thread → Display
|
||||
↓
|
||||
PerformanceMonitor (Monitoring)
|
||||
↓
|
||||
Console Logs (Metrics)
|
||||
```
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
### Performance Metrics
|
||||
|
||||
1. **Target Performance**:
|
||||
- Maintain 50+ FPS during streaming
|
||||
- Keep frame drop rate below 5%
|
||||
- Single operations under 16ms
|
||||
|
||||
2. **Monitoring Indicators**:
|
||||
- `📊 Performance Stats` - Real-time FPS and drop rate
|
||||
- `⚠️ UI Update Lag detected` - Frame drop warnings
|
||||
- `⏱️ Slow Operation` - Operation time alerts
|
||||
|
||||
### Testing Methodology
|
||||
|
||||
1. **Streaming Tests**:
|
||||
- Test with long-form AI responses (articles, code)
|
||||
- Monitor console output for performance warnings
|
||||
- Observe visual smoothness of text animation
|
||||
|
||||
2. **Load Testing**:
|
||||
- Rapid successive message sending
|
||||
- Large text blocks processing
|
||||
- Multiple concurrent operations
|
||||
|
||||
3. **Comparative Analysis**:
|
||||
- Before/after optimization measurements
|
||||
- Different device performance profiles
|
||||
- Various content types and sizes
|
||||
|
||||
### Debug Configuration
|
||||
|
||||
For development and testing purposes:
|
||||
|
||||
```swift
|
||||
// Example configuration adjustments (not implemented in production)
|
||||
// UIUpdateOptimizer.shared.batchSize = 10
|
||||
// UIUpdateOptimizer.shared.flushInterval = 0.05
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### UIUpdateOptimizer Algorithm
|
||||
|
||||
1. **Add Update**: New content is appended to pending queue
|
||||
2. **Threshold Check**: Evaluate if immediate flush is needed
|
||||
- Batch size reached (≥5 updates)
|
||||
- Time threshold exceeded (≥30ms since last flush)
|
||||
3. **Scheduling**: If not immediate, schedule delayed flush
|
||||
4. **Flush Execution**: Combine all pending updates and execute on main thread
|
||||
5. **Cleanup**: Clear queue and reset timing
|
||||
|
||||
### PerformanceMonitor Algorithm
|
||||
|
||||
1. **Update Recording**: Track each UI update call
|
||||
2. **Timing Analysis**: Calculate time difference between updates
|
||||
3. **Frame Drop Detection**: Compare against 25ms threshold
|
||||
4. **Statistics Calculation**: Compute FPS and drop rate every second
|
||||
5. **Logging**: Output performance metrics to console
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### In ViewModels
|
||||
|
||||
```swift
|
||||
func updateUI() {
|
||||
PerformanceMonitor.shared.recordUIUpdate()
|
||||
// UI update code here
|
||||
}
|
||||
|
||||
let result = PerformanceMonitor.shared.measureExecutionTime(operation: "Data Processing") {
|
||||
return processLargeDataSet()
|
||||
}
|
||||
```
|
||||
|
||||
### In Streaming Scenarios
|
||||
|
||||
```swift
|
||||
await UIUpdateOptimizer.shared.addUpdate(newText) { batchedContent in
|
||||
// Update UI with optimized batched content
|
||||
updateTextView(with: batchedContent)
|
||||
}
|
||||
|
||||
// When stream ends
|
||||
await UIUpdateOptimizer.shared.forceFlush { finalContent in
|
||||
finalizeTextDisplay(with: finalContent)
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Performance Issues
|
||||
|
||||
1. **High Frame Drop Rate**:
|
||||
- Check for blocking operations on main thread
|
||||
- Verify batch size configuration
|
||||
- Monitor memory usage
|
||||
|
||||
2. **Slow Operation Warnings**:
|
||||
- Profile specific operations causing delays
|
||||
- Consider background threading for heavy tasks
|
||||
- Optimize data processing algorithms
|
||||
|
||||
3. **Inconsistent Performance**:
|
||||
- Check device thermal state
|
||||
- Monitor memory pressure
|
||||
- Verify background app activity
|
||||
|
||||
### Diagnostic Tools
|
||||
|
||||
- **Console Monitoring**: Watch for performance log messages
|
||||
- **Xcode Instruments**: Use Time Profiler for detailed analysis
|
||||
- **Memory Graph**: Check for memory leaks affecting performance
|
||||
- **Energy Impact**: Monitor battery and thermal effects
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Proactive Monitoring**: Always call `recordUIUpdate()` for critical UI operations
|
||||
2. **Batch When Possible**: Use `UIUpdateOptimizer` for frequent updates
|
||||
3. **Measure Critical Paths**: Wrap expensive operations with `measureExecutionTime`
|
||||
4. **Test on Real Devices**: Performance varies significantly across device types
|
||||
5. **Monitor in Production**: Keep performance logging enabled during development
|
||||
|
||||
This performance optimization system ensures smooth user experience during AI text generation while providing developers with the tools needed to maintain and improve performance over time.
|
|
@ -0,0 +1,136 @@
|
|||
//
|
||||
// UIUpdateOptimizer.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/7.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* UIUpdateOptimizer - A utility for batching and throttling UI updates to improve performance
|
||||
*
|
||||
* This actor-based optimizer helps reduce the frequency of UI updates by batching multiple
|
||||
* updates together and applying throttling mechanisms. It's particularly useful for scenarios
|
||||
* like streaming text updates, real-time data feeds, or any situation where frequent UI
|
||||
* updates might cause performance issues.
|
||||
*
|
||||
* Key Features:
|
||||
* - Batches multiple updates into a single operation
|
||||
* - Applies time-based throttling to limit update frequency
|
||||
* - Thread-safe actor implementation
|
||||
* - Automatic flush mechanism for pending updates
|
||||
*
|
||||
* Usage Example:
|
||||
* ```swift
|
||||
* // For streaming text updates
|
||||
* await UIUpdateOptimizer.shared.addUpdate(newText) { batchedContent in
|
||||
* // Update UI with batched content
|
||||
* textView.text = batchedContent
|
||||
* }
|
||||
*
|
||||
* // Force flush remaining updates when stream ends
|
||||
* await UIUpdateOptimizer.shared.forceFlush { finalContent in
|
||||
* textView.text = finalContent
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Configuration:
|
||||
* - batchSize: Number of updates to batch before triggering immediate flush (default: 5)
|
||||
* - flushInterval: Time interval in seconds between automatic flushes (default: 0.03s / 30ms)
|
||||
*/
|
||||
actor UIUpdateOptimizer {
|
||||
static let shared = UIUpdateOptimizer()
|
||||
|
||||
private var pendingUpdates: [String] = []
|
||||
private var lastFlushTime: Date = Date()
|
||||
private var flushTask: Task<Void, Never>?
|
||||
|
||||
// Configuration constants
|
||||
private let batchSize: Int = 5 // Batch size threshold for immediate flush
|
||||
private let flushInterval: TimeInterval = 0.03 // 30ms throttling interval
|
||||
|
||||
private init() {}
|
||||
|
||||
/**
|
||||
* Adds a content update to the pending queue
|
||||
*
|
||||
* Updates are either flushed immediately if batch size or time threshold is reached,
|
||||
* or scheduled for delayed flushing to optimize performance.
|
||||
*
|
||||
* - Parameters:
|
||||
* - content: The content string to add to the update queue
|
||||
* - completion: Callback executed with the batched content when flushed
|
||||
*/
|
||||
func addUpdate(_ content: String, completion: @escaping (String) -> Void) {
|
||||
pendingUpdates.append(content)
|
||||
|
||||
// Determine if immediate flush is needed based on batch size or time interval
|
||||
let shouldFlushImmediately = pendingUpdates.count >= batchSize ||
|
||||
Date().timeIntervalSince(lastFlushTime) >= flushInterval
|
||||
|
||||
if shouldFlushImmediately {
|
||||
flushUpdates(completion: completion)
|
||||
} else {
|
||||
// Schedule delayed flush to optimize performance
|
||||
scheduleFlush(completion: completion)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedules a delayed flush operation
|
||||
*
|
||||
* Cancels any existing scheduled flush and creates a new one to avoid
|
||||
* excessive flush operations while maintaining responsiveness.
|
||||
*
|
||||
* - Parameter completion: Callback to execute when flush occurs
|
||||
*/
|
||||
private func scheduleFlush(completion: @escaping (String) -> Void) {
|
||||
// Cancel previous scheduled flush to avoid redundant operations
|
||||
flushTask?.cancel()
|
||||
|
||||
flushTask = Task {
|
||||
try? await Task.sleep(nanoseconds: UInt64(flushInterval * 1_000_000_000))
|
||||
|
||||
if !Task.isCancelled && !pendingUpdates.isEmpty {
|
||||
flushUpdates(completion: completion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flushes all pending updates immediately
|
||||
*
|
||||
* Combines all pending updates into a single string and executes the completion
|
||||
* callback on the main actor thread for UI updates.
|
||||
*
|
||||
* - Parameter completion: Callback executed with the combined content
|
||||
*/
|
||||
private func flushUpdates(completion: @escaping (String) -> Void) {
|
||||
guard !pendingUpdates.isEmpty else { return }
|
||||
|
||||
let batchedContent = pendingUpdates.joined()
|
||||
pendingUpdates.removeAll()
|
||||
lastFlushTime = Date()
|
||||
|
||||
Task { @MainActor in
|
||||
completion(batchedContent)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Forces immediate flush of any remaining pending updates
|
||||
*
|
||||
* This method should be called when you need to ensure all pending updates
|
||||
* are processed immediately, such as when a stream ends or the view is about
|
||||
* to disappear.
|
||||
*
|
||||
* - Parameter completion: Callback executed with any remaining content
|
||||
*/
|
||||
func forceFlush(completion: @escaping (String) -> Void) {
|
||||
if !pendingUpdates.isEmpty {
|
||||
flushUpdates(completion: completion)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,7 +11,6 @@ import AVFoundation
|
|||
import ExyteChat
|
||||
import ExyteMediaPicker
|
||||
|
||||
|
||||
final class LLMChatViewModel: ObservableObject {
|
||||
|
||||
private var llm: LLMInferenceEngineWrapper?
|
||||
|
@ -21,6 +20,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
@Published var messages: [Message] = []
|
||||
@Published var isModelLoaded = false
|
||||
@Published var isProcessing: Bool = false
|
||||
@Published var currentStreamingMessageId: String? = nil
|
||||
|
||||
@Published var useMmap: Bool = false
|
||||
|
||||
|
@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
let modelConfigManager: ModelConfigManager
|
||||
|
||||
var isDiffusionModel: Bool {
|
||||
return modelInfo.name.lowercased().contains("diffusion")
|
||||
return modelInfo.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||
|
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
), userType: .system)
|
||||
}
|
||||
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
||||
Task { @MainActor in
|
||||
print("Diffusion Model \(success)")
|
||||
|
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func sendToLLM(draft: DraftMessage) {
|
||||
self.send(draft: draft, userType: .user)
|
||||
if isModelLoaded {
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
self.getDiffusionResponse(draft: draft)
|
||||
} else {
|
||||
self.getLLMRespsonse(draft: draft)
|
||||
|
@ -166,9 +166,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
|
||||
Task {
|
||||
|
||||
let tempDir = FileManager.default.temporaryDirectory
|
||||
let imageName = UUID().uuidString + ".jpg"
|
||||
let tempImagePath = tempDir.appendingPathComponent(imageName).path
|
||||
let tempImagePath = FileOperationManager.shared.generateTempImagePath().path
|
||||
|
||||
var lastProcess:Int32 = 0
|
||||
|
||||
|
@ -198,7 +196,21 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func getLLMRespsonse(draft: DraftMessage) {
|
||||
Task {
|
||||
await llmState.setProcessing(true)
|
||||
await MainActor.run { self.isProcessing = true }
|
||||
await MainActor.run {
|
||||
self.isProcessing = true
|
||||
let emptyMessage = DraftMessage(
|
||||
text: "",
|
||||
thinkText: "",
|
||||
medias: [],
|
||||
recording: nil,
|
||||
replyMessage: nil,
|
||||
createdAt: Date()
|
||||
)
|
||||
self.send(draft: emptyMessage, userType: .assistant)
|
||||
if let lastMessage = self.messages.last {
|
||||
self.currentStreamingMessageId = lastMessage.id
|
||||
}
|
||||
}
|
||||
|
||||
var content = draft.text
|
||||
let medias = draft.medias
|
||||
|
@ -209,18 +221,10 @@ final class LLMChatViewModel: ObservableObject {
|
|||
continue
|
||||
}
|
||||
|
||||
let isInTempDirectory = url.path.contains("/tmp/")
|
||||
let fileName = url.lastPathComponent
|
||||
|
||||
if !isInTempDirectory {
|
||||
guard let fileUrl = AssetExtractor.copyFileToTmpDirectory(from: url, fileName: fileName) else {
|
||||
continue
|
||||
}
|
||||
let processedUrl = convertHEICImage(from: fileUrl)
|
||||
content = "<img>\(processedUrl?.path ?? "")</img>" + content
|
||||
} else {
|
||||
let processedUrl = convertHEICImage(from: url)
|
||||
content = "<img>\(processedUrl?.path ?? "")</img>" + content
|
||||
if let processedUrl = FileOperationManager.shared.processImageFile(from: url, fileName: fileName) {
|
||||
content = "<img>\(processedUrl.path)</img>" + content
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,13 +236,36 @@ final class LLMChatViewModel: ObservableObject {
|
|||
|
||||
let convertedContent = self.convertDeepSeekMutliChat(content: content)
|
||||
|
||||
await llmState.processContent(convertedContent, llm: self.llm) { [weak self] output in
|
||||
await llmState.processContent(convertedContent, llm: self.llm, showPerformance: true) { [weak self] output in
|
||||
guard let self = self else { return }
|
||||
|
||||
if output.contains("<eop>") {
|
||||
// force flush
|
||||
Task {
|
||||
await UIUpdateOptimizer.shared.forceFlush { finalOutput in
|
||||
if !finalOutput.isEmpty {
|
||||
self.send(draft: DraftMessage(
|
||||
text: finalOutput,
|
||||
thinkText: "",
|
||||
medias: [],
|
||||
recording: nil,
|
||||
replyMessage: nil,
|
||||
createdAt: Date()
|
||||
), userType: .assistant)
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
self.isProcessing = false
|
||||
self.currentStreamingMessageId = nil
|
||||
}
|
||||
await self.llmState.setProcessing(false)
|
||||
}
|
||||
return
|
||||
}
|
||||
Task { @MainActor in
|
||||
if (output.contains("<eop>")) {
|
||||
self?.isProcessing = false
|
||||
await self?.llmState.setProcessing(false)
|
||||
} else {
|
||||
self?.send(draft: DraftMessage(
|
||||
await UIUpdateOptimizer.shared.addUpdate(output) { output in
|
||||
self.send(draft: DraftMessage(
|
||||
text: output,
|
||||
thinkText: "",
|
||||
medias: [],
|
||||
|
@ -248,6 +275,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
), userType: .assistant)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -259,7 +287,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
}
|
||||
|
||||
private func convertDeepSeekMutliChat(content: String) -> String {
|
||||
if self.modelInfo.name.lowercased().contains("deepseek") {
|
||||
if self.modelInfo.modelName.lowercased().contains("deepseek") {
|
||||
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
*/
|
||||
|
@ -286,14 +314,11 @@ final class LLMChatViewModel: ObservableObject {
|
|||
}
|
||||
}
|
||||
|
||||
private func convertHEICImage(from url: URL) -> URL? {
|
||||
var fileUrl = url
|
||||
if fileUrl.isHEICImage() {
|
||||
if let convertedUrl = AssetExtractor.convertHEICToJPG(heicUrl: fileUrl) {
|
||||
fileUrl = convertedUrl
|
||||
}
|
||||
}
|
||||
return fileUrl
|
||||
// MARK: - Public Methods for File Operations
|
||||
|
||||
/// Cleans the model temporary folder using FileOperationManager
|
||||
func cleanModelTmpFolder() {
|
||||
FileOperationManager.shared.cleanModelTempFolder(modelPath: modelInfo.localPath)
|
||||
}
|
||||
|
||||
func onStart() {
|
||||
|
@ -311,14 +336,18 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func onStop() {
|
||||
ChatHistoryManager.shared.saveChat(
|
||||
historyId: historyId,
|
||||
modelId: modelInfo.modelId,
|
||||
modelName: modelInfo.name,
|
||||
modelId: modelInfo.id,
|
||||
modelName: modelInfo.modelName,
|
||||
messages: messages
|
||||
)
|
||||
|
||||
interactor.disconnect()
|
||||
llm = nil
|
||||
self.cleanTmpFolder()
|
||||
|
||||
FileOperationManager.shared.cleanTempDirectories()
|
||||
if !useMmap {
|
||||
FileOperationManager.shared.cleanModelTempFolder(modelPath: modelInfo.localPath)
|
||||
}
|
||||
}
|
||||
|
||||
func loadMoreMessage(before message: Message) {
|
||||
|
@ -326,40 +355,4 @@ final class LLMChatViewModel: ObservableObject {
|
|||
.sink { _ in }
|
||||
.store(in: &subscriptions)
|
||||
}
|
||||
|
||||
|
||||
func cleanModelTmpFolder() {
|
||||
let tmpFolderURL = URL(fileURLWithPath: self.modelInfo.localPath).appendingPathComponent("temp")
|
||||
self.cleanFolder(tmpFolderURL: tmpFolderURL)
|
||||
}
|
||||
|
||||
private func cleanTmpFolder() {
|
||||
let fileManager = FileManager.default
|
||||
let tmpDirectoryURL = fileManager.temporaryDirectory
|
||||
|
||||
self.cleanFolder(tmpFolderURL: tmpDirectoryURL)
|
||||
|
||||
if !useMmap {
|
||||
cleanModelTmpFolder()
|
||||
}
|
||||
}
|
||||
|
||||
private func cleanFolder(tmpFolderURL: URL) {
|
||||
let fileManager = FileManager.default
|
||||
do {
|
||||
let files = try fileManager.contentsOfDirectory(at: tmpFolderURL, includingPropertiesForKeys: nil)
|
||||
for file in files {
|
||||
if !file.absoluteString.lowercased().contains("networkdownload") {
|
||||
do {
|
||||
try fileManager.removeItem(at: file)
|
||||
print("Deleted file: \(file.path)")
|
||||
} catch {
|
||||
print("Error deleting file: \(file.path), \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("Error accessing tmp directory: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,12 +17,14 @@ struct LLMChatView: View {
|
|||
private let title: String
|
||||
private let modelPath: String
|
||||
|
||||
private let recorderSettings = RecorderSettings(audioFormatID: kAudioFormatLinearPCM, sampleRate: 44100, numberOfChannels: 2, linearPCMBitDepth: 16)
|
||||
private let recorderSettings = RecorderSettings(audioFormatID: kAudioFormatLinearPCM,
|
||||
sampleRate: 44100, numberOfChannels: 2,
|
||||
linearPCMBitDepth: 16)
|
||||
|
||||
@State private var showSettings = false
|
||||
|
||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||
self.title = modelInfo.name
|
||||
self.title = modelInfo.modelName
|
||||
self.modelPath = modelInfo.localPath
|
||||
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
|
||||
_viewModel = StateObject(wrappedValue: viewModel)
|
||||
|
@ -32,6 +34,15 @@ struct LLMChatView: View {
|
|||
ChatView(messages: viewModel.messages, chatType: .conversation) { draft in
|
||||
viewModel.sendToLLM(draft: draft)
|
||||
}
|
||||
messageBuilder: { message, positionInGroup, positionInCommentsGroup, showContextMenuClosure, messageActionClosure, showAttachmentClosure in
|
||||
LLMChatMessageView(
|
||||
message: message,
|
||||
positionInGroup: positionInGroup,
|
||||
showContextMenuClosure: showContextMenuClosure,
|
||||
messageActionClosure: messageActionClosure,
|
||||
showAttachmentClosure: showAttachmentClosure
|
||||
)
|
||||
}
|
||||
.setAvailableInput(
|
||||
self.title.lowercased().contains("vl") ? .textAndMedia :
|
||||
self.title.lowercased().contains("audio") ? .textAndAudio :
|
||||
|
@ -109,4 +120,28 @@ struct LLMChatView: View {
|
|||
}
|
||||
.onDisappear(perform: viewModel.onStop)
|
||||
}
|
||||
|
||||
// MARK: - LLM Chat Message Builder
|
||||
@ViewBuilder
|
||||
private func LLMChatMessageView(
|
||||
message: Message,
|
||||
positionInGroup: PositionInUserGroup,
|
||||
showContextMenuClosure: @escaping () -> Void,
|
||||
messageActionClosure: @escaping (Message, DefaultMessageMenuAction) -> Void,
|
||||
showAttachmentClosure: @escaping (Attachment) -> Void
|
||||
) -> some View {
|
||||
LLMMessageView(
|
||||
message: message,
|
||||
positionInGroup: positionInGroup,
|
||||
isAssistantMessage: !message.user.isCurrentUser,
|
||||
isStreamingMessage: viewModel.currentStreamingMessageId == message.id,
|
||||
showContextMenuClosure: {
|
||||
if !viewModel.isProcessing {
|
||||
showContextMenuClosure()
|
||||
}
|
||||
},
|
||||
messageActionClosure: messageActionClosure,
|
||||
showAttachmentClosure: showAttachmentClosure
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,338 @@
|
|||
//
|
||||
// LLMMessageTextView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/7.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
import MarkdownUI
|
||||
|
||||
/**
|
||||
* LLMMessageTextView - A specialized text view designed for LLM chat messages with typewriter animation
|
||||
*
|
||||
* This SwiftUI component provides an enhanced text display specifically designed for AI chat applications.
|
||||
* It supports both plain text and Markdown rendering with an optional typewriter animation effect
|
||||
* that creates a dynamic, engaging user experience during AI response streaming.
|
||||
*
|
||||
* Key Features:
|
||||
* - Typewriter animation for streaming AI responses
|
||||
* - Markdown support with custom styling
|
||||
* - Smart animation control based on message type and content length
|
||||
* - Automatic animation management with proper cleanup
|
||||
* - Performance-optimized character-by-character rendering
|
||||
*
|
||||
* Usage Examples:
|
||||
*
|
||||
* 1. Basic AI Message with Typewriter Effect:
|
||||
* ```swift
|
||||
* LLMMessageTextView(
|
||||
* text: "Hello! This is an AI response with typewriter animation.",
|
||||
* messageUseMarkdown: false,
|
||||
* messageId: "msg_001",
|
||||
* isAssistantMessage: true,
|
||||
* isStreamingMessage: true
|
||||
* )
|
||||
* ```
|
||||
*
|
||||
* 2. Markdown Message with Custom Styling:
|
||||
* ```swift
|
||||
* LLMMessageTextView(
|
||||
* text: "**Bold text** and *italic text* with `code blocks`",
|
||||
* messageUseMarkdown: true,
|
||||
* messageId: "msg_002",
|
||||
* isAssistantMessage: true,
|
||||
* isStreamingMessage: true
|
||||
* )
|
||||
* ```
|
||||
*
|
||||
* 3. User Message (No Animation):
|
||||
* ```swift
|
||||
* LLMMessageTextView(
|
||||
* text: "This is a user message",
|
||||
* messageUseMarkdown: false,
|
||||
* messageId: "msg_003",
|
||||
* isAssistantMessage: false,
|
||||
* isStreamingMessage: false
|
||||
* )
|
||||
* ```
|
||||
*
|
||||
* Animation Configuration:
|
||||
* - typingSpeed: 0.015 seconds per character (adjustable)
|
||||
* - chunkSize: 1 character per animation frame
|
||||
* - Minimum text length for animation: 5 characters
|
||||
* - Auto-cleanup on view disappear or streaming completion
|
||||
*/
|
||||
struct LLMMessageTextView: View {
|
||||
let text: String?
|
||||
let messageUseMarkdown: Bool
|
||||
let messageId: String
|
||||
let isAssistantMessage: Bool
|
||||
let isStreamingMessage: Bool // Whether this message is currently being streamed
|
||||
|
||||
@State private var displayedText: String = ""
|
||||
@State private var animationTimer: Timer?
|
||||
|
||||
// Typewriter animation configuration
|
||||
private let typingSpeed: TimeInterval = 0.015 // Time interval per character
|
||||
private let chunkSize: Int = 1 // Number of characters to display per frame
|
||||
|
||||
init(text: String?,
|
||||
messageUseMarkdown: Bool = false,
|
||||
messageId: String,
|
||||
isAssistantMessage: Bool = false,
|
||||
isStreamingMessage: Bool = false) {
|
||||
self.text = text
|
||||
self.messageUseMarkdown = messageUseMarkdown
|
||||
self.messageId = messageId
|
||||
self.isAssistantMessage = isAssistantMessage
|
||||
self.isStreamingMessage = isStreamingMessage
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
Group {
|
||||
if let text = text, !text.isEmpty {
|
||||
if isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
|
||||
typewriterView(text)
|
||||
} else {
|
||||
staticView(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
.onAppear {
|
||||
if let text = text, isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
|
||||
startTypewriterAnimation(for: text)
|
||||
} else if let text = text {
|
||||
displayedText = text
|
||||
}
|
||||
}
|
||||
.onDisappear {
|
||||
stopAnimation()
|
||||
}
|
||||
.onChange(of: text) { oldText, newText in
|
||||
handleTextChange(newText)
|
||||
}
|
||||
.onChange(of: isStreamingMessage) { oldIsStreaming, newIsStreaming in
|
||||
if !newIsStreaming {
|
||||
// Streaming ended, display complete text
|
||||
if let text = text {
|
||||
displayedText = text
|
||||
}
|
||||
stopAnimation()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether typewriter animation should be used
|
||||
*
|
||||
* Animation is enabled only for assistant messages with more than 5 characters
|
||||
* to avoid unnecessary animation for short responses.
|
||||
*/
|
||||
private var shouldUseTypewriter: Bool {
|
||||
// Enable typewriter effect only for assistant messages with sufficient length
|
||||
return isAssistantMessage && (text?.count ?? 0) > 5
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders text with typewriter animation effect
|
||||
*
|
||||
* - Parameter text: The complete text to be animated
|
||||
* - Returns: A view displaying the animated text with optional Markdown support
|
||||
*/
|
||||
@ViewBuilder
|
||||
private func typewriterView(_ text: String) -> some View {
|
||||
if messageUseMarkdown {
|
||||
Markdown(displayedText)
|
||||
.markdownBlockStyle(\.blockquote) { configuration in
|
||||
configuration.label
|
||||
.padding()
|
||||
.markdownTextStyle {
|
||||
FontSize(13)
|
||||
FontWeight(.light)
|
||||
BackgroundColor(nil)
|
||||
}
|
||||
.overlay(alignment: .leading) {
|
||||
Rectangle()
|
||||
.fill(Color.gray)
|
||||
.frame(width: 4)
|
||||
}
|
||||
.background(Color.gray.opacity(0.2))
|
||||
}
|
||||
} else {
|
||||
Text(displayedText)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders static text without animation
|
||||
*
|
||||
* - Parameter text: The text to be displayed
|
||||
* - Returns: A view displaying the complete text with optional Markdown support
|
||||
*/
|
||||
@ViewBuilder
|
||||
private func staticView(_ text: String) -> some View {
|
||||
if messageUseMarkdown {
|
||||
Markdown(text)
|
||||
.markdownBlockStyle(\.blockquote) { configuration in
|
||||
configuration.label
|
||||
.padding()
|
||||
.markdownTextStyle {
|
||||
FontSize(13)
|
||||
FontWeight(.light)
|
||||
BackgroundColor(nil)
|
||||
}
|
||||
.overlay(alignment: .leading) {
|
||||
Rectangle()
|
||||
.fill(Color.gray)
|
||||
.frame(width: 4)
|
||||
}
|
||||
.background(Color.gray.opacity(0.2))
|
||||
}
|
||||
} else {
|
||||
Text(text)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles text content changes during streaming
|
||||
*
|
||||
* This method intelligently manages animation continuation, restart, or direct display
|
||||
* based on the relationship between old and new text content.
|
||||
*
|
||||
* - Parameter newText: The updated text content
|
||||
*/
|
||||
private func handleTextChange(_ newText: String?) {
|
||||
guard let newText = newText else {
|
||||
displayedText = ""
|
||||
stopAnimation()
|
||||
return
|
||||
}
|
||||
|
||||
if isAssistantMessage && isStreamingMessage && shouldUseTypewriter {
|
||||
// Check if new text is an extension of current displayed text
|
||||
if newText.hasPrefix(displayedText) && newText != displayedText {
|
||||
// Continue typewriter animation
|
||||
continueTypewriterAnimation(with: newText)
|
||||
} else if newText != displayedText {
|
||||
// Restart animation with new content
|
||||
restartTypewriterAnimation(with: newText)
|
||||
}
|
||||
} else {
|
||||
// Display text directly without animation
|
||||
displayedText = newText
|
||||
stopAnimation()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates typewriter animation for the given text
|
||||
*
|
||||
* - Parameter text: The text to animate
|
||||
*/
|
||||
private func startTypewriterAnimation(for text: String) {
|
||||
displayedText = ""
|
||||
continueTypewriterAnimation(with: text)
|
||||
}
|
||||
|
||||
/**
|
||||
* Continues or resumes typewriter animation
|
||||
*
|
||||
* This method sets up a timer-based animation that progressively reveals
|
||||
* characters at the configured typing speed.
|
||||
*
|
||||
* - Parameter text: The complete text to animate
|
||||
*/
|
||||
private func continueTypewriterAnimation(with text: String) {
|
||||
guard displayedText.count < text.count else { return }
|
||||
|
||||
stopAnimation()
|
||||
|
||||
animationTimer = Timer.scheduledTimer(withTimeInterval: typingSpeed, repeats: true) { timer in
|
||||
DispatchQueue.main.async {
|
||||
self.appendNextCharacters(from: text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts typewriter animation with new content
|
||||
*
|
||||
* - Parameter text: The new text to animate
|
||||
*/
|
||||
private func restartTypewriterAnimation(with text: String) {
|
||||
stopAnimation()
|
||||
displayedText = ""
|
||||
startTypewriterAnimation(for: text)
|
||||
}
|
||||
|
||||
/**
|
||||
* Appends the next character(s) to the displayed text
|
||||
*
|
||||
* This method is called by the animation timer to progressively reveal
|
||||
* text characters. It handles proper string indexing and animation completion.
|
||||
*
|
||||
* - Parameter text: The source text to extract characters from
|
||||
*/
|
||||
private func appendNextCharacters(from text: String) {
|
||||
let currentLength = displayedText.count
|
||||
guard currentLength < text.count else {
|
||||
stopAnimation()
|
||||
return
|
||||
}
|
||||
|
||||
let endIndex = min(currentLength + chunkSize, text.count)
|
||||
let startIndex = text.index(text.startIndex, offsetBy: currentLength)
|
||||
let targetIndex = text.index(text.startIndex, offsetBy: endIndex)
|
||||
|
||||
let newChars = text[startIndex..<targetIndex]
|
||||
displayedText.append(String(newChars))
|
||||
|
||||
if displayedText.count >= text.count {
|
||||
stopAnimation()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops and cleans up the typewriter animation
|
||||
*
|
||||
* This method should be called when animation is no longer needed
|
||||
* to prevent memory leaks and unnecessary timer execution.
|
||||
*/
|
||||
private func stopAnimation() {
|
||||
animationTimer?.invalidate()
|
||||
animationTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview Provider
|
||||
struct LLMMessageTextView_Previews: PreviewProvider {
|
||||
static var previews: some View {
|
||||
VStack(spacing: 20) {
|
||||
LLMMessageTextView(
|
||||
text: "This is a typewriter animation demo text. Hello, this demonstrates the streaming effect!",
|
||||
messageUseMarkdown: false,
|
||||
messageId: "test1",
|
||||
isAssistantMessage: true,
|
||||
isStreamingMessage: true
|
||||
)
|
||||
|
||||
LLMMessageTextView(
|
||||
text: "**Bold text** and *italic text* with markdown support.",
|
||||
messageUseMarkdown: true,
|
||||
messageId: "test2",
|
||||
isAssistantMessage: true,
|
||||
isStreamingMessage: true
|
||||
)
|
||||
|
||||
LLMMessageTextView(
|
||||
text: "Regular user message without animation.",
|
||||
messageUseMarkdown: false,
|
||||
messageId: "test3",
|
||||
isAssistantMessage: false,
|
||||
isStreamingMessage: false
|
||||
)
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,275 @@
|
|||
//
|
||||
// LLMMessageView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/7.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
import ExyteChat
|
||||
|
||||
// MARK: - Custom Message View
|
||||
struct LLMMessageView: View {
|
||||
let message: Message
|
||||
let positionInGroup: PositionInUserGroup
|
||||
let isAssistantMessage: Bool
|
||||
let isStreamingMessage: Bool
|
||||
let showContextMenuClosure: () -> Void
|
||||
let messageActionClosure: (Message, DefaultMessageMenuAction) -> Void
|
||||
let showAttachmentClosure: (Attachment) -> Void
|
||||
|
||||
let theme = ChatTheme(
|
||||
colors: .init(
|
||||
messageMyBG: .customBlue.opacity(0.2),
|
||||
messageFriendBG: .clear
|
||||
),
|
||||
images: .init(
|
||||
attach: Image(systemName: "photo"),
|
||||
attachCamera: Image("attachCamera", bundle: .current)
|
||||
)
|
||||
)
|
||||
|
||||
@State var avatarViewSize: CGSize = .zero
|
||||
@State var timeSize: CGSize = .zero
|
||||
|
||||
static let widthWithMedia: CGFloat = 204
|
||||
static let horizontalNoAvatarPadding: CGFloat = 8
|
||||
static let horizontalAvatarPadding: CGFloat = 8
|
||||
static let horizontalTextPadding: CGFloat = 12
|
||||
static let horizontalAttachmentPadding: CGFloat = 1
|
||||
static let horizontalBubblePadding: CGFloat = 70
|
||||
|
||||
var additionalMediaInset: CGFloat {
|
||||
message.attachments.count > 1 ? LLMMessageView.horizontalAttachmentPadding * 2 : 0
|
||||
}
|
||||
|
||||
var showAvatar: Bool {
|
||||
positionInGroup == .single
|
||||
|| positionInGroup == .last
|
||||
}
|
||||
|
||||
var topPadding: CGFloat {
|
||||
positionInGroup == .single || positionInGroup == .first ? 8 : 4
|
||||
}
|
||||
|
||||
var bottomPadding: CGFloat {
|
||||
0
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .top, spacing: 0) {
|
||||
if !message.user.isCurrentUser {
|
||||
avatarView
|
||||
}
|
||||
|
||||
VStack(alignment: message.user.isCurrentUser ? .trailing : .leading, spacing: 2) {
|
||||
bubbleView(message)
|
||||
}
|
||||
}
|
||||
.padding(.top, topPadding)
|
||||
.padding(.bottom, bottomPadding)
|
||||
.padding(.trailing, message.user.isCurrentUser ? LLMMessageView.horizontalNoAvatarPadding : 0)
|
||||
.padding(message.user.isCurrentUser ? .leading : .trailing, message.user.isCurrentUser ? LLMMessageView.horizontalBubblePadding : 0)
|
||||
.frame(maxWidth: UIScreen.main.bounds.width, alignment: message.user.isCurrentUser ? .trailing : .leading)
|
||||
.contentShape(Rectangle())
|
||||
.onLongPressGesture {
|
||||
showContextMenuClosure()
|
||||
}
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
func bubbleView(_ message: Message) -> some View {
|
||||
VStack(alignment: .leading, spacing: 0) {
|
||||
if !message.attachments.isEmpty {
|
||||
attachmentsView(message)
|
||||
}
|
||||
|
||||
if !message.text.isEmpty {
|
||||
textWithTimeView(message)
|
||||
}
|
||||
|
||||
if let recording = message.recording {
|
||||
VStack(alignment: .trailing, spacing: 8) {
|
||||
recordingView(recording)
|
||||
messageTimeView()
|
||||
.padding(.bottom, 8)
|
||||
.padding(.trailing, 12)
|
||||
}
|
||||
}
|
||||
}
|
||||
.bubbleBackground(message, theme: theme)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
var avatarView: some View {
|
||||
Group {
|
||||
if showAvatar {
|
||||
AsyncImage(url: message.user.avatarURL) { image in
|
||||
image
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fill)
|
||||
} placeholder: {
|
||||
Circle()
|
||||
.fill(Color.gray.opacity(0.3))
|
||||
}
|
||||
.frame(width: 32, height: 32)
|
||||
.clipShape(Circle())
|
||||
.contentShape(Circle())
|
||||
} else {
|
||||
Color.clear.frame(width: 32, height: 32)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, LLMMessageView.horizontalAvatarPadding)
|
||||
.sizeGetter($avatarViewSize)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
func attachmentsView(_ message: Message) -> some View {
|
||||
ForEach(message.attachments, id: \.id) { attachment in
|
||||
AsyncImage(url: attachment.thumbnail) { image in
|
||||
image
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fit)
|
||||
} placeholder: {
|
||||
Rectangle()
|
||||
.fill(Color.gray.opacity(0.3))
|
||||
}
|
||||
.frame(maxWidth: LLMMessageView.widthWithMedia, maxHeight: 200)
|
||||
.cornerRadius(12)
|
||||
.onTapGesture {
|
||||
showAttachmentClosure(attachment)
|
||||
}
|
||||
}
|
||||
.applyIf(message.attachments.count > 1) {
|
||||
$0
|
||||
.padding(.top, LLMMessageView.horizontalAttachmentPadding)
|
||||
.padding(.horizontal, LLMMessageView.horizontalAttachmentPadding)
|
||||
}
|
||||
.overlay(alignment: .bottomTrailing) {
|
||||
if message.text.isEmpty {
|
||||
messageTimeView(needsCapsule: true)
|
||||
.padding(4)
|
||||
}
|
||||
}
|
||||
.contentShape(Rectangle())
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
func textWithTimeView(_ message: Message) -> some View {
|
||||
// Message View with Type Writer Animation
|
||||
let messageView = LLMMessageTextView(
|
||||
text: message.text,
|
||||
messageUseMarkdown: true,
|
||||
messageId: message.id,
|
||||
isAssistantMessage: isAssistantMessage,
|
||||
isStreamingMessage: isStreamingMessage
|
||||
)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
.padding(.horizontal, LLMMessageView.horizontalTextPadding)
|
||||
|
||||
HStack(alignment: .lastTextBaseline, spacing: 12) {
|
||||
messageView
|
||||
if !message.attachments.isEmpty {
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
func recordingView(_ recording: Recording) -> some View {
|
||||
HStack {
|
||||
Image(systemName: "mic.fill")
|
||||
.foregroundColor(.blue)
|
||||
Text("Audio Message")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
.padding(.horizontal, LLMMessageView.horizontalTextPadding)
|
||||
.padding(.top, 8)
|
||||
}
|
||||
|
||||
func messageTimeView(needsCapsule: Bool = false) -> some View {
|
||||
Group {
|
||||
if needsCapsule {
|
||||
Text(DateFormatter.timeFormatter.string(from: message.createdAt))
|
||||
.font(.caption)
|
||||
.foregroundColor(.white)
|
||||
.opacity(0.8)
|
||||
.padding(.top, 4)
|
||||
.padding(.bottom, 4)
|
||||
.padding(.horizontal, 8)
|
||||
.background {
|
||||
Capsule()
|
||||
.foregroundColor(.black.opacity(0.4))
|
||||
}
|
||||
} else {
|
||||
Text(DateFormatter.timeFormatter.string(from: message.createdAt))
|
||||
.font(.caption)
|
||||
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyTimeText : theme.colors.messageFriendTimeText)
|
||||
}
|
||||
}
|
||||
.sizeGetter($timeSize)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - View Extensions
|
||||
extension View {
|
||||
@ViewBuilder
|
||||
func sizeGetter(_ size: Binding<CGSize>) -> some View {
|
||||
self.background(
|
||||
GeometryReader { geometry in
|
||||
Color.clear
|
||||
.preference(key: SizePreferenceKey.self, value: geometry.size)
|
||||
}
|
||||
)
|
||||
.onPreferenceChange(SizePreferenceKey.self) { newSize in
|
||||
size.wrappedValue = newSize
|
||||
}
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
func applyIf<Content: View>(_ condition: Bool, transform: (Self) -> Content) -> some View {
|
||||
if condition {
|
||||
transform(self)
|
||||
} else {
|
||||
self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preference Key
|
||||
struct SizePreferenceKey: PreferenceKey {
|
||||
static var defaultValue: CGSize = .zero
|
||||
static func reduce(value: inout CGSize, nextValue: () -> CGSize) {}
|
||||
}
|
||||
|
||||
// MARK: - Date Formatter Extension
|
||||
extension DateFormatter {
|
||||
static let timeFormatter: DateFormatter = {
|
||||
let formatter = DateFormatter()
|
||||
formatter.timeStyle = .short
|
||||
return formatter
|
||||
}()
|
||||
}
|
||||
|
||||
extension View {
|
||||
@ViewBuilder
|
||||
func bubbleBackground(_ message: Message, theme: ChatTheme, isReply: Bool = false) -> some View {
|
||||
let radius: CGFloat = !message.attachments.isEmpty ? 12 : 20
|
||||
let additionalMediaInset: CGFloat = message.attachments.count > 1 ? 2 : 0
|
||||
self
|
||||
.frame(width: message.attachments.isEmpty ? nil : LLMMessageView.widthWithMedia + additionalMediaInset)
|
||||
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyText : theme.colors.messageFriendText)
|
||||
.background {
|
||||
if isReply || !message.text.isEmpty || message.recording != nil {
|
||||
RoundedRectangle(cornerRadius: radius)
|
||||
.foregroundColor(message.user.isCurrentUser ? theme.colors.messageMyBG : theme.colors.messageFriendBG)
|
||||
.opacity(isReply ? 0.5 : 1)
|
||||
}
|
||||
}
|
||||
.cornerRadius(radius)
|
||||
}
|
||||
}
|
|
@ -37,7 +37,7 @@ struct ModelSettingsView: View {
|
|||
var body: some View {
|
||||
NavigationView {
|
||||
Form {
|
||||
Section(header: Text("Model Configuration")) {
|
||||
Section {
|
||||
Toggle("Use mmap", isOn: $viewModel.useMmap)
|
||||
.onChange(of: viewModel.useMmap) { newValue in
|
||||
viewModel.modelConfigManager.updateUseMmap(newValue)
|
||||
|
@ -47,11 +47,13 @@ struct ModelSettingsView: View {
|
|||
viewModel.cleanModelTmpFolder()
|
||||
showAlert = true
|
||||
}
|
||||
} header: {
|
||||
Text("Model Configuration")
|
||||
}
|
||||
|
||||
// Diffusion Settings
|
||||
if viewModel.isDiffusionModel {
|
||||
Section(header: Text("Diffusion Settings")) {
|
||||
Section {
|
||||
Stepper(value: $iterations, in: 1...100) {
|
||||
HStack {
|
||||
Text("Iterations")
|
||||
|
@ -84,9 +86,11 @@ struct ModelSettingsView: View {
|
|||
}
|
||||
}
|
||||
}
|
||||
} header: {
|
||||
Text("Diffusion Settings")
|
||||
}
|
||||
} else {
|
||||
Section(header: Text("Sampling Strategy")) {
|
||||
Section {
|
||||
Picker("Sampler Type", selection: $selectedSampler) {
|
||||
ForEach(SamplerType.allCases, id: \.self) { sampler in
|
||||
Text(sampler.displayName)
|
||||
|
@ -218,6 +222,8 @@ struct ModelSettingsView: View {
|
|||
default:
|
||||
EmptyView()
|
||||
}
|
||||
} header: {
|
||||
Text("Sampling Strategy")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -272,5 +278,3 @@ struct ModelSettingsView: View {
|
|||
viewModel.modelConfigManager.updateMixedSamplers(orderedSelection)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
//
|
||||
// ChatHistoryItemView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/16.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct ChatHistoryItemView: View {
|
||||
let history: ChatHistory
|
||||
|
||||
var body: some View {
|
||||
HStack(spacing: 12) {
|
||||
|
||||
ModelIconView(modelId: history.modelId)
|
||||
.frame(width: 36, height: 36)
|
||||
.clipShape(Circle())
|
||||
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
|
||||
if let firstMessage = history.messages.last {
|
||||
Text(String(firstMessage.content.prefix(50)) + "...")
|
||||
.lineLimit(2)
|
||||
.font(.system(size: 14))
|
||||
}
|
||||
|
||||
HStack {
|
||||
VStack(alignment: .leading) {
|
||||
Text(history.modelName)
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(.gray)
|
||||
|
||||
Text(history.updatedAt.formatAgo())
|
||||
.font(.system(size: 10))
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
//
|
||||
// SideMenuView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/16.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct SideMenuView: View {
|
||||
@Binding var isOpen: Bool
|
||||
@Binding var selectedHistory: ChatHistory?
|
||||
|
||||
@Binding var histories: [ChatHistory]
|
||||
|
||||
@State private var showingAlert = false
|
||||
@State private var historyToDelete: ChatHistory?
|
||||
|
||||
@State private var dragOffset: CGFloat = 0
|
||||
|
||||
var body: some View {
|
||||
GeometryReader { geometry in
|
||||
VStack {
|
||||
HStack {
|
||||
Text(NSLocalizedString("ChatHistroyTitle", comment: "Chat Histroy Title"))
|
||||
.fontWeight(.medium)
|
||||
.font(.system(size: 20))
|
||||
Spacer()
|
||||
}
|
||||
.padding(.top, 80)
|
||||
.padding(.leading)
|
||||
|
||||
List {
|
||||
ForEach(histories.sorted(by: { $0.updatedAt > $1.updatedAt })) { history in
|
||||
|
||||
ChatHistoryItemView(history: history)
|
||||
.onTapGesture {
|
||||
selectedHistory = history
|
||||
isOpen = false
|
||||
}
|
||||
.onLongPressGesture {
|
||||
historyToDelete = history
|
||||
showingAlert = true
|
||||
}
|
||||
.listRowBackground(Color.sidemenuBg)
|
||||
}
|
||||
}
|
||||
.background(Color.sidemenuBg)
|
||||
.listStyle(PlainListStyle())
|
||||
}
|
||||
.background(Color.sidemenuBg)
|
||||
.frame(width: geometry.size.width * 0.8)
|
||||
.offset(x: isOpen ? 0 : -geometry.size.width * 0.8)
|
||||
.animation(.easeOut, value: isOpen)
|
||||
.gesture(
|
||||
DragGesture()
|
||||
.onChanged { value in
|
||||
if value.translation.width < 0 {
|
||||
dragOffset = value.translation.width
|
||||
}
|
||||
}
|
||||
.onEnded { value in
|
||||
if value.translation.width < -geometry.size.width * 0.25 {
|
||||
isOpen = false
|
||||
}
|
||||
dragOffset = 0
|
||||
}
|
||||
)
|
||||
.alert("Delete History", isPresented: $showingAlert) {
|
||||
Button("Cancel", role: .cancel) {}
|
||||
Button("Delete", role: .destructive) {
|
||||
if let history = historyToDelete {
|
||||
deleteHistory(history)
|
||||
}
|
||||
}
|
||||
} message: {
|
||||
Text("Are you sure you want to delete this history?")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func deleteHistory(_ history: ChatHistory) {
|
||||
// 更新存储
|
||||
ChatHistoryManager.shared.deleteHistory(history)
|
||||
// 从历史记录中删除
|
||||
histories.removeAll { $0.id == history.id }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
//
|
||||
// ChatHistoryItemView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/16.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct ChatHistoryItemView: View {
|
||||
let history: ChatHistory
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
|
||||
if let firstMessage = history.messages.last {
|
||||
Text(String(firstMessage.content.prefix(200)))
|
||||
.lineLimit(1)
|
||||
.font(.system(size: 15, weight: .medium))
|
||||
.foregroundColor(.primary)
|
||||
}
|
||||
|
||||
HStack(alignment: .bottom) {
|
||||
|
||||
ModelIconView(modelId: history.modelId)
|
||||
.frame(width: 20, height: 20)
|
||||
.clipShape(Circle())
|
||||
.padding(.trailing, 0)
|
||||
|
||||
Text(history.modelName)
|
||||
.lineLimit(1)
|
||||
.font(.system(size: 12, weight: .semibold))
|
||||
.foregroundColor(.black.opacity(0.5))
|
||||
|
||||
Spacer()
|
||||
|
||||
Text(history.updatedAt.formatAgo())
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(.black.opacity(0.5))
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 10)
|
||||
.padding(.horizontal, 0)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
//
|
||||
// SideMenuView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/16.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct SideMenuView: View {
|
||||
@Binding var isOpen: Bool
|
||||
@Binding var selectedHistory: ChatHistory?
|
||||
@Binding var histories: [ChatHistory]
|
||||
@Binding var navigateToMainSettings: Bool
|
||||
|
||||
@State private var showingAlert = false
|
||||
@State private var historyToDelete: ChatHistory?
|
||||
@State private var navigateToSettings = false
|
||||
|
||||
@State private var dragOffset: CGFloat = 0
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
GeometryReader { geometry in
|
||||
VStack {
|
||||
HStack {
|
||||
Text(NSLocalizedString("ChatHistroyTitle", comment: "Chat Histroy Title"))
|
||||
.fontWeight(.medium)
|
||||
.font(.system(size: 20))
|
||||
Spacer()
|
||||
}
|
||||
.padding(.top, 80)
|
||||
.padding(.leading, 12)
|
||||
|
||||
List {
|
||||
ForEach(histories.sorted(by: { $0.updatedAt > $1.updatedAt })) { history in
|
||||
|
||||
ChatHistoryItemView(history: history)
|
||||
.onTapGesture {
|
||||
selectedHistory = history
|
||||
isOpen = false
|
||||
}
|
||||
.onLongPressGesture {
|
||||
historyToDelete = history
|
||||
showingAlert = true
|
||||
}
|
||||
.listRowBackground(Color.sidemenuBg)
|
||||
.listRowSeparator(.hidden)
|
||||
}
|
||||
}
|
||||
.background(Color.sidemenuBg)
|
||||
.listStyle(PlainListStyle())
|
||||
|
||||
Spacer()
|
||||
|
||||
HStack {
|
||||
Button(action: {
|
||||
isOpen = false
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
|
||||
navigateToMainSettings = true
|
||||
}
|
||||
}) {
|
||||
HStack {
|
||||
Image(systemName: "gear")
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fit)
|
||||
.frame(width: 20, height: 20)
|
||||
}
|
||||
.foregroundColor(.primary)
|
||||
.padding(.leading)
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
.padding(EdgeInsets(top: 10, leading: 12, bottom: 30, trailing: 0))
|
||||
}
|
||||
.background(Color.sidemenuBg)
|
||||
.frame(width: geometry.size.width * 0.8)
|
||||
.offset(x: isOpen ? 0 : -geometry.size.width * 0.8)
|
||||
.animation(.easeOut, value: isOpen)
|
||||
.gesture(
|
||||
DragGesture()
|
||||
.onChanged { value in
|
||||
if value.translation.width < 0 {
|
||||
dragOffset = value.translation.width
|
||||
}
|
||||
}
|
||||
.onEnded { value in
|
||||
if value.translation.width < -geometry.size.width * 0.25 {
|
||||
isOpen = false
|
||||
}
|
||||
dragOffset = 0
|
||||
}
|
||||
)
|
||||
.alert("Delete History", isPresented: $showingAlert) {
|
||||
Button("Cancel", role: .cancel) {}
|
||||
Button(LocalizedStringKey("button.delete"), role: .destructive) {
|
||||
if let history = historyToDelete {
|
||||
deleteHistory(history)
|
||||
}
|
||||
}
|
||||
} message: {
|
||||
Text("Are you sure you want to delete this history?")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func deleteHistory(_ history: ChatHistory) {
|
||||
ChatHistoryManager.shared.deleteHistory(history)
|
||||
histories.removeAll { $0.id == history.id }
|
||||
}
|
||||
}
|
||||
|
||||
struct SettingsFullScreenView: View {
|
||||
@Binding var isPresented: Bool
|
||||
|
||||
var body: some View {
|
||||
NavigationView {
|
||||
SettingsView()
|
||||
.navigationTitle("Settings")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .navigationBarLeading) {
|
||||
Button(action: {
|
||||
isPresented = false
|
||||
}) {
|
||||
Image(systemName: "xmark")
|
||||
.foregroundColor(.primary)
|
||||
.fontWeight(.medium)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.navigationViewStyle(StackNavigationViewStyle())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
//
|
||||
// LLMInferenceEngineWrapper.h
|
||||
// mnn-llm
|
||||
//
|
||||
// Created by wangzhaode on 2023/12/14.
|
||||
//
|
||||
|
||||
#ifndef LLMInferenceEngineWrapper_h
|
||||
#define LLMInferenceEngineWrapper_h
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
typedef void (^CompletionHandler)(BOOL success);
|
||||
typedef void (^OutputHandler)(NSString * _Nonnull output);
|
||||
|
||||
// MARK: - Benchmark Related Types
|
||||
|
||||
/**
|
||||
* Progress type enumeration for structured benchmark reporting
|
||||
*/
|
||||
typedef NS_ENUM(NSInteger, BenchmarkProgressType) {
|
||||
BenchmarkProgressTypeUnknown = 0,
|
||||
BenchmarkProgressTypeInitializing = 1,
|
||||
BenchmarkProgressTypeWarmingUp = 2,
|
||||
BenchmarkProgressTypeRunningTest = 3,
|
||||
BenchmarkProgressTypeProcessingResults = 4,
|
||||
BenchmarkProgressTypeCompleted = 5,
|
||||
BenchmarkProgressTypeStopping = 6
|
||||
};
|
||||
|
||||
/**
|
||||
* Structured progress information for benchmark
|
||||
*/
|
||||
@interface BenchmarkProgressInfo : NSObject
|
||||
|
||||
@property (nonatomic, assign) NSInteger progress; // 0-100
|
||||
@property (nonatomic, strong) NSString *statusMessage; // Status description
|
||||
@property (nonatomic, assign) BenchmarkProgressType progressType;
|
||||
@property (nonatomic, assign) NSInteger currentIteration;
|
||||
@property (nonatomic, assign) NSInteger totalIterations;
|
||||
@property (nonatomic, assign) NSInteger nPrompt;
|
||||
@property (nonatomic, assign) NSInteger nGenerate;
|
||||
@property (nonatomic, assign) float runTimeSeconds;
|
||||
@property (nonatomic, assign) float prefillTimeSeconds;
|
||||
@property (nonatomic, assign) float decodeTimeSeconds;
|
||||
@property (nonatomic, assign) float prefillSpeed;
|
||||
@property (nonatomic, assign) float decodeSpeed;
|
||||
|
||||
@end
|
||||
|
||||
/**
|
||||
* Benchmark result structure
|
||||
*/
|
||||
@interface BenchmarkResult : NSObject
|
||||
|
||||
@property (nonatomic, assign) BOOL success;
|
||||
@property (nonatomic, strong, nullable) NSString *errorMessage;
|
||||
@property (nonatomic, strong) NSArray<NSNumber *> *prefillTimesUs;
|
||||
@property (nonatomic, strong) NSArray<NSNumber *> *decodeTimesUs;
|
||||
@property (nonatomic, strong) NSArray<NSNumber *> *sampleTimesUs;
|
||||
@property (nonatomic, assign) NSInteger promptTokens;
|
||||
@property (nonatomic, assign) NSInteger generateTokens;
|
||||
@property (nonatomic, assign) NSInteger repeatCount;
|
||||
@property (nonatomic, assign) BOOL kvCacheEnabled;
|
||||
|
||||
@end
|
||||
|
||||
// Benchmark callback blocks
|
||||
typedef void (^BenchmarkProgressCallback)(BenchmarkProgressInfo *progressInfo);
|
||||
typedef void (^BenchmarkErrorCallback)(NSString *error);
|
||||
typedef void (^BenchmarkIterationCompleteCallback)(NSString *detailedStats);
|
||||
typedef void (^BenchmarkCompleteCallback)(BenchmarkResult *result);
|
||||
|
||||
/**
|
||||
* LLMInferenceEngineWrapper - A high-level Objective-C wrapper for MNN LLM inference engine
|
||||
*
|
||||
* This class provides a convenient interface for integrating MNN's Large Language Model
|
||||
* inference capabilities into iOS applications with enhanced error handling, performance
|
||||
* optimization, and thread safety.
|
||||
*/
|
||||
@interface LLMInferenceEngineWrapper : NSObject
|
||||
|
||||
/**
|
||||
* Initialize the LLM inference engine with a model path
|
||||
*
|
||||
* @param modelPath The file system path to the model directory
|
||||
* @param completion Completion handler called with success/failure status
|
||||
* @return Initialized instance of LLMInferenceEngineWrapper
|
||||
*/
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion;
|
||||
|
||||
/**
|
||||
* Process user input and generate streaming LLM response
|
||||
*
|
||||
* @param input The user's input text to process
|
||||
* @param output Callback block that receives streaming output chunks
|
||||
*/
|
||||
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output;
|
||||
|
||||
/**
|
||||
* Process user input and generate streaming LLM response with optional performance output
|
||||
*
|
||||
* @param input The user's input text to process
|
||||
* @param output Callback block that receives streaming output chunks
|
||||
* @param showPerformance Whether to output performance statistics after response completion
|
||||
*/
|
||||
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output showPerformance:(BOOL)showPerformance;
|
||||
|
||||
/**
|
||||
* Add chat prompts from an array of dictionaries to the conversation history
|
||||
*
|
||||
* @param array NSArray containing NSDictionary objects with chat messages
|
||||
*/
|
||||
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array;
|
||||
|
||||
/**
|
||||
* Set the configuration for the LLM engine using a JSON string
|
||||
*
|
||||
* @param jsonStr JSON string containing configuration parameters
|
||||
*/
|
||||
- (void)setConfigWithJSONString:(NSString *)jsonStr;
|
||||
|
||||
/**
|
||||
* Check if model is ready for inference
|
||||
*
|
||||
* @return YES if model is loaded and ready
|
||||
*/
|
||||
- (BOOL)isModelReady;
|
||||
|
||||
/**
|
||||
* Get current processing status
|
||||
*
|
||||
* @return YES if currently processing an inference request
|
||||
*/
|
||||
- (BOOL)isProcessing;
|
||||
|
||||
/**
|
||||
* Cancel ongoing inference (if supported)
|
||||
*/
|
||||
- (void)cancelInference;
|
||||
|
||||
/**
|
||||
* Get chat history count
|
||||
*
|
||||
* @return Number of messages in chat history
|
||||
*/
|
||||
- (NSUInteger)getChatHistoryCount;
|
||||
|
||||
/**
|
||||
* Clear chat history
|
||||
*/
|
||||
- (void)clearChatHistory;
|
||||
|
||||
// MARK: - Benchmark Methods
|
||||
|
||||
/**
|
||||
* Run official benchmark following llm_bench.cpp approach
|
||||
*
|
||||
* @param backend Backend type (0 for CPU)
|
||||
* @param threads Number of threads
|
||||
* @param useMmap Whether to use memory mapping
|
||||
* @param power Power setting
|
||||
* @param precision Precision setting (2 for low precision)
|
||||
* @param memory Memory setting (2 for low memory)
|
||||
* @param dynamicOption Dynamic optimization option
|
||||
* @param nPrompt Number of prompt tokens
|
||||
* @param nGenerate Number of tokens to generate
|
||||
* @param nRepeat Number of repetitions
|
||||
* @param kvCache Whether to use KV cache
|
||||
* @param progressCallback Progress update callback
|
||||
* @param errorCallback Error callback
|
||||
* @param iterationCompleteCallback Iteration completion callback
|
||||
* @param completeCallback Final completion callback
|
||||
*/
|
||||
- (void)runOfficialBenchmarkWithBackend:(NSInteger)backend
|
||||
threads:(NSInteger)threads
|
||||
useMmap:(BOOL)useMmap
|
||||
power:(NSInteger)power
|
||||
precision:(NSInteger)precision
|
||||
memory:(NSInteger)memory
|
||||
dynamicOption:(NSInteger)dynamicOption
|
||||
nPrompt:(NSInteger)nPrompt
|
||||
nGenerate:(NSInteger)nGenerate
|
||||
nRepeat:(NSInteger)nRepeat
|
||||
kvCache:(BOOL)kvCache
|
||||
progressCallback:(BenchmarkProgressCallback _Nullable)progressCallback
|
||||
errorCallback:(BenchmarkErrorCallback _Nullable)errorCallback
|
||||
iterationCompleteCallback:(BenchmarkIterationCompleteCallback _Nullable)iterationCompleteCallback
|
||||
completeCallback:(BenchmarkCompleteCallback _Nullable)completeCallback;
|
||||
|
||||
/**
|
||||
* Stop running benchmark
|
||||
*/
|
||||
- (void)stopBenchmark;
|
||||
|
||||
/**
|
||||
* Check if benchmark is currently running
|
||||
*
|
||||
* @return YES if benchmark is running
|
||||
*/
|
||||
- (BOOL)isBenchmarkRunning;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif /* LLMInferenceEngineWrapper_h */
|
File diff suppressed because it is too large
Load Diff
|
@ -1,33 +0,0 @@
|
|||
//
|
||||
// LLMInferenceEngineWrapper.h
|
||||
// mnn-llm
|
||||
//
|
||||
// Created by wangzhaode on 2023/12/14.
|
||||
//
|
||||
|
||||
#ifndef LLMInferenceEngineWrapper_h
|
||||
#define LLMInferenceEngineWrapper_h
|
||||
|
||||
|
||||
// LLMInferenceEngineWrapper.h
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
typedef void (^CompletionHandler)(BOOL success);
|
||||
typedef void (^OutputHandler)(NSString * _Nonnull output);
|
||||
|
||||
@interface LLMInferenceEngineWrapper : NSObject
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion;
|
||||
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output;
|
||||
|
||||
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array;
|
||||
|
||||
- (void)setConfigWithJSONString:(NSString *)jsonStr;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif /* LLMInferenceEngineWrapper_h */
|
|
@ -1,274 +0,0 @@
|
|||
//
|
||||
// LLMInferenceEngineWrapper.m
|
||||
// mnn-llm
|
||||
//
|
||||
// Created by wangzhaode on 2023/12/14.
|
||||
//
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <MNN/llm/llm.hpp>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "LLMInferenceEngineWrapper.h"
|
||||
|
||||
using namespace MNN::Transformer;
|
||||
|
||||
using ChatMessage = std::pair<std::string, std::string>;
|
||||
static std::vector<ChatMessage> history{};
|
||||
|
||||
@implementation LLMInferenceEngineWrapper {
|
||||
std::shared_ptr<Llm> llm;
|
||||
}
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath completion:(CompletionHandler)completion {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
|
||||
BOOL success = [self loadModelFromPath:modelPath];
|
||||
// MARK: Test Local Model
|
||||
// BOOL success = [self loadModel];
|
||||
|
||||
dispatch_async(dispatch_get_main_queue(), ^{
|
||||
completion(success);
|
||||
});
|
||||
});
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
bool remove_directory(const std::string& path) {
|
||||
try {
|
||||
std::filesystem::remove_all(path); // 删除目录及其内容
|
||||
return true;
|
||||
} catch (const std::filesystem::filesystem_error& e) {
|
||||
std::cerr << "Error removing directory: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
- (BOOL)loadModel {
|
||||
if (!llm) {
|
||||
NSString *bundleDirectory = [[NSBundle mainBundle] bundlePath];
|
||||
std::string model_dir = [bundleDirectory UTF8String];
|
||||
std::string config_path = model_dir + "/config.json";
|
||||
llm.reset(Llm::createLLM(config_path));
|
||||
NSString *tempDirectory = NSTemporaryDirectory();
|
||||
llm->set_config("{\"tmp_path\":\"" + std::string([tempDirectory UTF8String]) + "\", \"use_mmap\":true}");
|
||||
llm->load();
|
||||
}
|
||||
return YES;
|
||||
}
|
||||
|
||||
- (BOOL)loadModelFromPath:(NSString *)modelPath {
|
||||
if (!llm) {
|
||||
std::string config_path = std::string([modelPath UTF8String]) + "/config.json";
|
||||
|
||||
// Read the config file to get use_mmap value
|
||||
NSError *error = nil;
|
||||
NSData *configData = [NSData dataWithContentsOfFile:[NSString stringWithUTF8String:config_path.c_str()]];
|
||||
NSDictionary *configDict = [NSJSONSerialization JSONObjectWithData:configData options:0 error:&error];
|
||||
// If use_mmap key doesn't exist, default to YES
|
||||
BOOL useMmap = configDict[@"use_mmap"] == nil ? YES : [configDict[@"use_mmap"] boolValue];
|
||||
|
||||
llm.reset(Llm::createLLM(config_path));
|
||||
if (!llm) {
|
||||
return NO;
|
||||
}
|
||||
|
||||
// Create temp directory inside the modelPath folder
|
||||
std::string model_path_str([modelPath UTF8String]);
|
||||
std::string temp_directory_path = model_path_str + "/temp";
|
||||
|
||||
struct stat info;
|
||||
if (stat(temp_directory_path.c_str(), &info) == 0) {
|
||||
// Directory exists, so remove it
|
||||
if (!remove_directory(temp_directory_path)) {
|
||||
std::cerr << "Failed to remove existing temp directory: " << temp_directory_path << std::endl;
|
||||
return NO;
|
||||
}
|
||||
std::cerr << "Existing temp directory removed: " << temp_directory_path << std::endl;
|
||||
}
|
||||
|
||||
// Now create the temp directory
|
||||
if (mkdir(temp_directory_path.c_str(), 0777) != 0) {
|
||||
std::cerr << "Failed to create temp directory: " << temp_directory_path << std::endl;
|
||||
return NO;
|
||||
}
|
||||
std::cerr << "Temp directory created: " << temp_directory_path << std::endl;
|
||||
|
||||
// NSLog(@"useMmap value: %@", useMmap ? @"YES" : @"NO");
|
||||
|
||||
// Explicitly convert BOOL to bool and ensure proper string conversion
|
||||
bool useMmapCpp = (useMmap == YES);
|
||||
std::string configStr = "{\"tmp_path\":\"" + temp_directory_path + "\", \"use_mmap\":" + (useMmapCpp ? "true" : "false") + "}";
|
||||
// Debug print to check the final config string
|
||||
// NSLog(@"Config string: %s", configStr.c_str());
|
||||
|
||||
llm->set_config(configStr);
|
||||
|
||||
llm->load();
|
||||
}
|
||||
else {
|
||||
std::cerr << "Warmming:: LLM have already been created!" << std::endl;
|
||||
}
|
||||
return YES;
|
||||
}
|
||||
|
||||
- (void)setConfigWithJSONString:(NSString *)jsonStr {
|
||||
|
||||
if (!llm) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (jsonStr) {
|
||||
const char *cString = [jsonStr UTF8String];
|
||||
std::string stdString(cString);
|
||||
|
||||
llm->set_config(stdString);
|
||||
} else {
|
||||
NSLog(@"Error: JSON string is nil or invalid.");
|
||||
}
|
||||
}
|
||||
|
||||
// llm stream buffer with callback
|
||||
class LlmStreamBuffer : public std::streambuf {
|
||||
public:
|
||||
using CallBack = std::function<void(const char* str, size_t len)>;
|
||||
LlmStreamBuffer(CallBack callback) : callback_(callback) {}
|
||||
|
||||
protected:
|
||||
virtual std::streamsize xsputn(const char* s, std::streamsize n) override {
|
||||
if (callback_) {
|
||||
callback_(s, n);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
private:
|
||||
CallBack callback_ = nullptr;
|
||||
};
|
||||
|
||||
- (void)processInput:(NSString *)input withOutput:(OutputHandler)output {
|
||||
if (llm == nil) {
|
||||
output(@"Error: Model not loaded");
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_LOW, 0), ^{
|
||||
|
||||
LlmStreamBuffer::CallBack callback = [output](const char* str, size_t len) {
|
||||
if (output) {
|
||||
NSString *nsOutput = [[NSString alloc] initWithBytes:str
|
||||
length:len
|
||||
encoding:NSUTF8StringEncoding];
|
||||
if (nsOutput) {
|
||||
output(nsOutput);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
LlmStreamBuffer streambuf(callback);
|
||||
std::ostream os(&streambuf);
|
||||
|
||||
history.emplace_back(ChatMessage("user", [input UTF8String]));
|
||||
|
||||
if (std::string([input UTF8String]) == "benchmark") {
|
||||
[self performBenchmarkWithOutput:&os];
|
||||
} else {
|
||||
llm->response(history, &os, "<eop>", 999999);
|
||||
}
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
// New method to handle benchmarking
|
||||
- (void)performBenchmarkWithOutput:(std::ostream *)os {
|
||||
std::string model_dir = [[[NSBundle mainBundle] bundlePath] UTF8String];
|
||||
std::string prompt_file = model_dir + "/bench.txt";
|
||||
std::ifstream prompt_fs(prompt_file);
|
||||
std::vector<std::string> prompts;
|
||||
std::string prompt;
|
||||
while (std::getline(prompt_fs, prompt)) {
|
||||
if (prompt.substr(0, 1) == "#") {
|
||||
continue;
|
||||
}
|
||||
std::string::size_type pos = 0;
|
||||
while ((pos = prompt.find("\\n", pos)) != std::string::npos) {
|
||||
prompt.replace(pos, 2, "\n");
|
||||
pos += 1;
|
||||
}
|
||||
prompts.push_back(prompt);
|
||||
}
|
||||
|
||||
int prompt_len = 0;
|
||||
int decode_len = 0;
|
||||
int64_t prefill_time = 0;
|
||||
int64_t decode_time = 0;
|
||||
|
||||
auto context = llm->getContext();
|
||||
for (const auto& p : prompts) {
|
||||
llm->response(p, os, "\n");
|
||||
prompt_len += context->prompt_len;
|
||||
decode_len += context->gen_seq_len;
|
||||
prefill_time += context->prefill_us;
|
||||
decode_time += context->decode_us;
|
||||
}
|
||||
|
||||
float prefill_s = prefill_time / 1e6;
|
||||
float decode_s = decode_time / 1e6;
|
||||
|
||||
*os << "\n#################################\n"
|
||||
<< "prompt tokens num = " << prompt_len << "\n"
|
||||
<< "decode tokens num = " << decode_len << "\n"
|
||||
<< "prefill time = " << std::fixed << std::setprecision(2) << prefill_s << " s\n"
|
||||
<< "decode time = " << std::fixed << std::setprecision(2) << decode_s << " s\n"
|
||||
<< "prefill speed = " << std::fixed << std::setprecision(2) << (prefill_s > 0 ? prompt_len / prefill_s : 0) << " tok/s\n"
|
||||
<< "decode speed = " << std::fixed << std::setprecision(2) << (decode_s > 0 ? decode_len / decode_s : 0) << " tok/s\n"
|
||||
<< "##################################\n";
|
||||
*os << "<eop>";
|
||||
}
|
||||
|
||||
- (void)dealloc {
|
||||
std::cerr << "llm dealloc reset" << std::endl;
|
||||
history.clear();
|
||||
llm.reset();
|
||||
llm = nil;
|
||||
}
|
||||
|
||||
- (void)init:(const std::vector<std::string>&)chatHistory {
|
||||
history.clear();
|
||||
history.emplace_back("system", "You are a helpful assistant.");
|
||||
|
||||
for (size_t i = 0; i < chatHistory.size(); ++i) {
|
||||
history.emplace_back(i % 2 == 0 ? "user" : "assistant", chatHistory[i]);
|
||||
}
|
||||
}
|
||||
|
||||
- (void)addPromptsFromArray:(NSArray<NSDictionary *> *)array {
|
||||
|
||||
history.clear();
|
||||
|
||||
for (NSDictionary *dict in array) {
|
||||
[self addPromptsFromDictionary:dict];
|
||||
}
|
||||
}
|
||||
|
||||
- (void)addPromptsFromDictionary:(NSDictionary *)dictionary {
|
||||
for (NSString *key in dictionary) {
|
||||
NSString *value = dictionary[key];
|
||||
|
||||
std::string keyString = [key UTF8String];
|
||||
std::string valueString = [value UTF8String];
|
||||
|
||||
history.emplace_back(ChatMessage(keyString, valueString));
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
|
@ -12,11 +12,15 @@ struct MNNLLMiOSApp: App {
|
|||
|
||||
init() {
|
||||
UIView.appearance().overrideUserInterfaceStyle = .light
|
||||
|
||||
let savedLanguage = LanguageManager.shared.currentLanguage
|
||||
UserDefaults.standard.set([savedLanguage], forKey: "AppleLanguages")
|
||||
UserDefaults.standard.synchronize()
|
||||
}
|
||||
|
||||
var body: some Scene {
|
||||
WindowGroup {
|
||||
ModelListView()
|
||||
MainTabView()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
//
|
||||
// BenchmarkResultsHelper.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import Darwin
|
||||
|
||||
/**
|
||||
* Helper class for processing and formatting benchmark test results.
|
||||
* Provides statistical analysis, formatting utilities, and device information
|
||||
* for benchmark result display and sharing.
|
||||
*/
|
||||
class BenchmarkResultsHelper {
|
||||
static let shared = BenchmarkResultsHelper()
|
||||
|
||||
private init() {}
|
||||
|
||||
// MARK: - Results Processing & Statistics
|
||||
|
||||
/// Processes test results to generate comprehensive benchmark statistics
|
||||
/// - Parameter testResults: Array of completed test instances
|
||||
/// - Returns: Processed statistics including speed metrics and configuration details
|
||||
func processTestResults(_ testResults: [TestInstance]) -> BenchmarkStatistics {
|
||||
guard !testResults.isEmpty else {
|
||||
return BenchmarkStatistics.empty
|
||||
}
|
||||
|
||||
let firstTest = testResults[0]
|
||||
let configText = "Backend: CPU, Threads: \(firstTest.threads), Memory: Low, Precision: Low"
|
||||
|
||||
var prefillStats: SpeedStatistics?
|
||||
var decodeStats: SpeedStatistics?
|
||||
var totalTokensProcessed = 0
|
||||
|
||||
// Calculate prefill (prompt processing) statistics
|
||||
let allPrefillSpeeds = testResults.flatMap { test in
|
||||
test.getTokensPerSecond(tokens: test.nPrompt, timesUs: test.prefillUs)
|
||||
}
|
||||
|
||||
if !allPrefillSpeeds.isEmpty {
|
||||
let avgPrefill = allPrefillSpeeds.reduce(0, +) / Double(allPrefillSpeeds.count)
|
||||
let stdevPrefill = calculateStandardDeviation(values: allPrefillSpeeds, mean: avgPrefill)
|
||||
prefillStats = SpeedStatistics(average: avgPrefill, stdev: stdevPrefill, label: "Prompt Processing")
|
||||
}
|
||||
|
||||
// Calculate decode (token generation) statistics
|
||||
let allDecodeSpeeds = testResults.flatMap { test in
|
||||
test.getTokensPerSecond(tokens: test.nGenerate, timesUs: test.decodeUs)
|
||||
}
|
||||
|
||||
if !allDecodeSpeeds.isEmpty {
|
||||
let avgDecode = allDecodeSpeeds.reduce(0, +) / Double(allDecodeSpeeds.count)
|
||||
let stdevDecode = calculateStandardDeviation(values: allDecodeSpeeds, mean: avgDecode)
|
||||
decodeStats = SpeedStatistics(average: avgDecode, stdev: stdevDecode, label: "Token Generation")
|
||||
}
|
||||
|
||||
// Calculate total tokens processed across all tests
|
||||
totalTokensProcessed = testResults.reduce(0) { sum, test in
|
||||
return sum + (test.nPrompt * test.prefillUs.count) + (test.nGenerate * test.decodeUs.count)
|
||||
}
|
||||
|
||||
return BenchmarkStatistics(
|
||||
configText: configText,
|
||||
prefillStats: prefillStats,
|
||||
decodeStats: decodeStats,
|
||||
totalTokensProcessed: totalTokensProcessed,
|
||||
totalTests: testResults.count
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculates standard deviation for a set of values
|
||||
/// - Parameters:
|
||||
/// - values: Array of numeric values
|
||||
/// - mean: Pre-calculated mean of the values
|
||||
/// - Returns: Standard deviation value
|
||||
private func calculateStandardDeviation(values: [Double], mean: Double) -> Double {
|
||||
guard values.count > 1 else { return 0.0 }
|
||||
|
||||
let variance = values.reduce(0) { sum, value in
|
||||
let diff = value - mean
|
||||
return sum + (diff * diff)
|
||||
} / Double(values.count - 1)
|
||||
|
||||
return sqrt(variance)
|
||||
}
|
||||
|
||||
// MARK: - Formatting & Display
|
||||
|
||||
/// Formats speed statistics with average and standard deviation
|
||||
/// - Parameter stats: Speed statistics to format
|
||||
/// - Returns: Formatted string like "42.5 ± 3.2 tok/s"
|
||||
func formatSpeedStatisticsLine(_ stats: SpeedStatistics) -> String {
|
||||
return String(format: "%.1f ± %.1f tok/s", stats.average, stats.stdev)
|
||||
}
|
||||
|
||||
/// Returns the label-only portion of speed statistics
|
||||
/// - Parameter stats: Speed statistics object
|
||||
/// - Returns: Human-readable label (e.g., "Prompt Processing")
|
||||
func formatSpeedLabelOnly(_ stats: SpeedStatistics) -> String {
|
||||
return stats.label
|
||||
}
|
||||
|
||||
/// Formats model parameter summary for display
|
||||
/// - Parameters:
|
||||
/// - totalTokens: Total number of tokens processed
|
||||
/// - totalTests: Total number of tests completed
|
||||
/// - Returns: Formatted summary string
|
||||
func formatModelParams(totalTokens: Int, totalTests: Int) -> String {
|
||||
return "Total Tokens: \(totalTokens), Tests: \(totalTests)"
|
||||
}
|
||||
|
||||
/// Formats memory usage with percentage and absolute values
|
||||
/// - Parameters:
|
||||
/// - maxMemoryKb: Peak memory usage in kilobytes
|
||||
/// - totalKb: Total system memory in kilobytes
|
||||
/// - Returns: Tuple containing formatted value and percentage label
|
||||
func formatMemoryUsage(maxMemoryKb: Int64, totalKb: Int64) -> (valueText: String, labelText: String) {
|
||||
let maxMemoryMB = Double(maxMemoryKb) / 1024.0
|
||||
let totalMemoryGB = Double(totalKb) / (1024.0 * 1024.0)
|
||||
let percentage = (Double(maxMemoryKb) / Double(totalKb)) * 100.0
|
||||
|
||||
let valueText = String(format: "%.1f MB", maxMemoryMB)
|
||||
let labelText = String(format: "%.1f%% of %.1f GB", percentage, totalMemoryGB)
|
||||
|
||||
return (valueText, labelText)
|
||||
}
|
||||
|
||||
// MARK: - Device & System Information
|
||||
|
||||
/// Gets comprehensive device information including model and iOS version
|
||||
/// - Returns: Formatted device info string (e.g., "iPhone 14 Pro, iOS 17.0")
|
||||
func getDeviceInfo() -> String {
|
||||
return DeviceInfoHelper.shared.getDeviceInfo()
|
||||
}
|
||||
|
||||
/// Gets total system memory in kilobytes
|
||||
/// - Returns: System memory size in KB
|
||||
func getTotalSystemMemoryKb() -> Int64 {
|
||||
return Int64(ProcessInfo.processInfo.physicalMemory) / 1024
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
//
|
||||
// BenchmarkErrorCode.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Enumeration of possible error codes that can occur during benchmark execution.
|
||||
* Provides specific error identification for different failure scenarios.
|
||||
*/
|
||||
enum BenchmarkErrorCode: Int {
|
||||
case benchmarkFailedUnknown = 30
|
||||
case testInstanceFailed = 40
|
||||
case modelNotInitialized = 50
|
||||
case benchmarkRunning = 99
|
||||
case benchmarkStopped = 100
|
||||
case nativeError = 0
|
||||
case modelError = 2
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
//
|
||||
// BenchmarkProgress.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure containing detailed progress information for benchmark execution.
|
||||
* Provides real-time metrics including timing data and performance statistics.
|
||||
*/
|
||||
struct BenchmarkProgress {
|
||||
let progress: Int // 0-100
|
||||
let statusMessage: String
|
||||
let progressType: ProgressType
|
||||
let currentIteration: Int
|
||||
let totalIterations: Int
|
||||
let nPrompt: Int
|
||||
let nGenerate: Int
|
||||
let runTimeSeconds: Float
|
||||
let prefillTimeSeconds: Float
|
||||
let decodeTimeSeconds: Float
|
||||
let prefillSpeed: Float
|
||||
let decodeSpeed: Float
|
||||
|
||||
init(progress: Int,
|
||||
statusMessage: String,
|
||||
progressType: ProgressType = .unknown,
|
||||
currentIteration: Int = 0,
|
||||
totalIterations: Int = 0,
|
||||
nPrompt: Int = 0,
|
||||
nGenerate: Int = 0,
|
||||
runTimeSeconds: Float = 0.0,
|
||||
prefillTimeSeconds: Float = 0.0,
|
||||
decodeTimeSeconds: Float = 0.0,
|
||||
prefillSpeed: Float = 0.0,
|
||||
decodeSpeed: Float = 0.0) {
|
||||
self.progress = progress
|
||||
self.statusMessage = statusMessage
|
||||
self.progressType = progressType
|
||||
self.currentIteration = currentIteration
|
||||
self.totalIterations = totalIterations
|
||||
self.nPrompt = nPrompt
|
||||
self.nGenerate = nGenerate
|
||||
self.runTimeSeconds = runTimeSeconds
|
||||
self.prefillTimeSeconds = prefillTimeSeconds
|
||||
self.decodeTimeSeconds = decodeTimeSeconds
|
||||
self.prefillSpeed = prefillSpeed
|
||||
self.decodeSpeed = decodeSpeed
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
//
|
||||
// BenchmarkResult.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure containing the results of a completed benchmark test.
|
||||
* Encapsulates test instance data along with success status and error information.
|
||||
*/
|
||||
struct BenchmarkResult {
|
||||
let testInstance: TestInstance
|
||||
let success: Bool
|
||||
let errorMessage: String?
|
||||
|
||||
init(testInstance: TestInstance, success: Bool, errorMessage: String? = nil) {
|
||||
self.testInstance = testInstance
|
||||
self.success = success
|
||||
self.errorMessage = errorMessage
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
//
|
||||
// BenchmarkResults.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure containing comprehensive benchmark results for display and sharing.
|
||||
* Aggregates test results, memory usage, and metadata for result presentation.
|
||||
*/
|
||||
struct BenchmarkResults {
|
||||
let modelDisplayName: String
|
||||
let maxMemoryKb: Int64
|
||||
let testResults: [TestInstance]
|
||||
let timestamp: String
|
||||
|
||||
init(modelDisplayName: String, maxMemoryKb: Int64, testResults: [TestInstance], timestamp: String) {
|
||||
self.modelDisplayName = modelDisplayName
|
||||
self.maxMemoryKb = maxMemoryKb
|
||||
self.testResults = testResults
|
||||
self.timestamp = timestamp
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
//
|
||||
// BenchmarkStatistics.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure containing comprehensive statistical analysis of benchmark results.
|
||||
* Aggregates performance metrics, configuration details, and test summary information.
|
||||
*/
|
||||
struct BenchmarkStatistics {
|
||||
let configText: String
|
||||
let prefillStats: SpeedStatistics?
|
||||
let decodeStats: SpeedStatistics?
|
||||
let totalTokensProcessed: Int
|
||||
let totalTests: Int
|
||||
|
||||
static let empty = BenchmarkStatistics(
|
||||
configText: "",
|
||||
prefillStats: nil,
|
||||
decodeStats: nil,
|
||||
totalTokensProcessed: 0,
|
||||
totalTests: 0
|
||||
)
|
||||
}
|
|
@ -0,0 +1,142 @@
|
|||
//
|
||||
// DeviceInfoHelper.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import UIKit
|
||||
|
||||
/**
|
||||
* Helper class for retrieving device information including model identification
|
||||
* and system details. Provides device-specific information for benchmark results.
|
||||
*/
|
||||
class DeviceInfoHelper {
|
||||
static let shared = DeviceInfoHelper()
|
||||
|
||||
private init() {}
|
||||
|
||||
/// Gets the device model identifier (e.g., "iPhone14,7")
|
||||
func getDeviceIdentifier() -> String {
|
||||
var systemInfo = utsname()
|
||||
uname(&systemInfo)
|
||||
|
||||
let machineMirror = Mirror(reflecting: systemInfo.machine)
|
||||
let identifier = machineMirror.children.reduce("") { identifier, element in
|
||||
guard let value = element.value as? Int8, value != 0 else { return identifier }
|
||||
return identifier + String(UnicodeScalar(UInt8(value)))
|
||||
}
|
||||
|
||||
return identifier
|
||||
}
|
||||
|
||||
/// Gets the user-friendly device name (e.g., "iPhone 13 mini")
|
||||
func getDeviceModelName() -> String {
|
||||
let identifier = getDeviceIdentifier()
|
||||
return mapIdentifierToModelName(identifier)
|
||||
}
|
||||
|
||||
/// Gets detailed device information including model and system version
|
||||
func getDeviceInfo() -> String {
|
||||
let device = UIDevice.current
|
||||
let systemVersion = device.systemVersion
|
||||
let modelName = getDeviceModelName()
|
||||
return "\(modelName), iOS \(systemVersion)"
|
||||
}
|
||||
|
||||
private func mapIdentifierToModelName(_ identifier: String) -> String {
|
||||
// iPhone mappings
|
||||
let iPhoneMappings: [String: String] = [
|
||||
// iPhone 13 series
|
||||
"iPhone14,4": "iPhone 13 mini",
|
||||
"iPhone14,5": "iPhone 13",
|
||||
"iPhone14,2": "iPhone 13 Pro",
|
||||
"iPhone14,3": "iPhone 13 Pro Max",
|
||||
|
||||
// iPhone 14 series
|
||||
"iPhone14,7": "iPhone 14",
|
||||
"iPhone14,8": "iPhone 14 Plus",
|
||||
"iPhone15,2": "iPhone 14 Pro",
|
||||
"iPhone15,3": "iPhone 14 Pro Max",
|
||||
|
||||
// iPhone 15 series
|
||||
"iPhone15,4": "iPhone 15",
|
||||
"iPhone15,5": "iPhone 15 Plus",
|
||||
"iPhone16,1": "iPhone 15 Pro",
|
||||
"iPhone16,2": "iPhone 15 Pro Max",
|
||||
|
||||
// iPhone 16 series
|
||||
"iPhone17,1": "iPhone 16",
|
||||
"iPhone17,2": "iPhone 16 Plus",
|
||||
"iPhone17,3": "iPhone 16 Pro",
|
||||
"iPhone17,4": "iPhone 16 Pro Max",
|
||||
|
||||
// iPhone SE series
|
||||
"iPhone12,8": "iPhone SE (2nd generation)",
|
||||
"iPhone14,6": "iPhone SE (3rd generation)",
|
||||
|
||||
// Older iPhones
|
||||
"iPhone13,1": "iPhone 12 mini",
|
||||
"iPhone13,2": "iPhone 12",
|
||||
"iPhone13,3": "iPhone 12 Pro",
|
||||
"iPhone13,4": "iPhone 12 Pro Max",
|
||||
"iPhone12,1": "iPhone 11",
|
||||
"iPhone12,3": "iPhone 11 Pro",
|
||||
"iPhone12,5": "iPhone 11 Pro Max",
|
||||
]
|
||||
|
||||
// iPad mappings
|
||||
let iPadMappings: [String: String] = [
|
||||
// iPad Pro 12.9-inch
|
||||
"iPad13,8": "iPad Pro (12.9-inch) (5th generation)",
|
||||
"iPad13,9": "iPad Pro (12.9-inch) (5th generation)",
|
||||
"iPad13,10": "iPad Pro (12.9-inch) (5th generation)",
|
||||
"iPad13,11": "iPad Pro (12.9-inch) (5th generation)",
|
||||
"iPad14,5": "iPad Pro (12.9-inch) (6th generation)",
|
||||
"iPad14,6": "iPad Pro (12.9-inch) (6th generation)",
|
||||
|
||||
// iPad Pro 11-inch
|
||||
"iPad13,4": "iPad Pro (11-inch) (3rd generation)",
|
||||
"iPad13,5": "iPad Pro (11-inch) (3rd generation)",
|
||||
"iPad13,6": "iPad Pro (11-inch) (3rd generation)",
|
||||
"iPad13,7": "iPad Pro (11-inch) (3rd generation)",
|
||||
"iPad14,3": "iPad Pro (11-inch) (4th generation)",
|
||||
"iPad14,4": "iPad Pro (11-inch) (4th generation)",
|
||||
|
||||
// iPad Air
|
||||
"iPad13,1": "iPad Air (4th generation)",
|
||||
"iPad13,2": "iPad Air (4th generation)",
|
||||
"iPad13,16": "iPad Air (5th generation)",
|
||||
"iPad13,17": "iPad Air (5th generation)",
|
||||
|
||||
// iPad mini
|
||||
"iPad14,1": "iPad mini (6th generation)",
|
||||
"iPad14,2": "iPad mini (6th generation)",
|
||||
|
||||
// iPad (regular)
|
||||
"iPad12,1": "iPad (9th generation)",
|
||||
"iPad12,2": "iPad (9th generation)",
|
||||
"iPad13,18": "iPad (10th generation)",
|
||||
"iPad13,19": "iPad (10th generation)",
|
||||
]
|
||||
|
||||
// Try iPhone mappings first
|
||||
if let modelName = iPhoneMappings[identifier] {
|
||||
return modelName
|
||||
}
|
||||
|
||||
// Try iPad mappings
|
||||
if let modelName = iPadMappings[identifier] {
|
||||
return modelName
|
||||
}
|
||||
|
||||
// Check for simulator
|
||||
if identifier == "x86_64" || identifier == "i386" {
|
||||
return "Simulator"
|
||||
}
|
||||
|
||||
// Return raw identifier if no mapping found
|
||||
return identifier
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
//
|
||||
// ModelItem.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure representing a model item with download state information.
|
||||
* Used for tracking model availability and download progress in the benchmark interface.
|
||||
*/
|
||||
struct ModelItem: Identifiable, Equatable {
|
||||
let id = UUID()
|
||||
let modelId: String
|
||||
let displayName: String
|
||||
let isLocal: Bool
|
||||
let localPath: String?
|
||||
let size: Int64?
|
||||
let downloadState: DownloadState
|
||||
|
||||
enum DownloadState: Equatable {
|
||||
case notStarted
|
||||
case downloading(progress: Double)
|
||||
case completed
|
||||
case failed(error: String)
|
||||
case paused
|
||||
}
|
||||
|
||||
static func == (lhs: ModelItem, rhs: ModelItem) -> Bool {
|
||||
return lhs.modelId == rhs.modelId
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
//
|
||||
// ModelListManager.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Manager class for integrating with ModelListViewModel to provide
|
||||
* downloaded models for benchmark testing.
|
||||
*/
|
||||
class ModelListManager {
|
||||
static let shared = ModelListManager()
|
||||
|
||||
private let modelListViewModel = ModelListViewModel()
|
||||
|
||||
private init() {}
|
||||
|
||||
/// Loads available models, filtering for downloaded models suitable for benchmarking
|
||||
/// - Returns: Array of downloaded ModelInfo objects
|
||||
/// - Throws: Error if model loading fails
|
||||
func loadModels() async throws -> [ModelInfo] {
|
||||
// Ensure models are loaded from the view model
|
||||
await modelListViewModel.fetchModels()
|
||||
|
||||
// Return only downloaded models that are available for benchmark
|
||||
return modelListViewModel.models.filter { $0.isDownloaded }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
//
|
||||
// ProgressType.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Enumeration representing different stages of benchmark execution progress.
|
||||
* Used to track and display the current state of benchmark operations.
|
||||
*/
|
||||
enum ProgressType: Int, CaseIterable {
|
||||
case unknown = 0
|
||||
case initializing
|
||||
case warmingUp
|
||||
case runningTest
|
||||
case processingResults
|
||||
case completed
|
||||
case stopping
|
||||
|
||||
var description: String {
|
||||
switch self {
|
||||
case .unknown: return "Unknown"
|
||||
case .initializing: return "Initializing benchmark..."
|
||||
case .warmingUp: return "Warming up..."
|
||||
case .runningTest: return "Running test"
|
||||
case .processingResults: return "Processing results..."
|
||||
case .completed: return "All tests completed"
|
||||
case .stopping: return "Stopping benchmark..."
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
//
|
||||
// RuntimeParameters.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Configuration parameters for benchmark runtime environment.
|
||||
* Defines hardware and execution settings for benchmark tests.
|
||||
*/
|
||||
struct RuntimeParameters {
|
||||
let backends: [Int]
|
||||
let threads: [Int]
|
||||
let useMmap: Bool
|
||||
let power: [Int]
|
||||
let precision: [Int]
|
||||
let memory: [Int]
|
||||
let dynamicOption: [Int]
|
||||
|
||||
static let `default` = RuntimeParameters(
|
||||
backends: [0], // CPU
|
||||
threads: [4],
|
||||
useMmap: false,
|
||||
power: [0],
|
||||
precision: [2], // Low precision
|
||||
memory: [2], // Low memory
|
||||
dynamicOption: [0]
|
||||
)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
//
|
||||
// SpeedStatistics.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Structure containing statistical analysis of benchmark speed metrics.
|
||||
* Provides average, standard deviation, and descriptive label for performance data.
|
||||
*/
|
||||
struct SpeedStatistics {
|
||||
let average: Double
|
||||
let stdev: Double
|
||||
let label: String
|
||||
|
||||
init(average: Double, stdev: Double, label: String) {
|
||||
self.average = average
|
||||
self.stdev = stdev
|
||||
self.label = label
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
//
|
||||
// TestInstance.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import Combine
|
||||
|
||||
/**
|
||||
* Observable class representing a single benchmark test instance.
|
||||
* Contains test configuration parameters and stores timing results.
|
||||
*/
|
||||
class TestInstance: ObservableObject, Identifiable {
|
||||
let id = UUID()
|
||||
let modelConfigFile: String
|
||||
let modelType: String
|
||||
let modelSize: Int64
|
||||
let threads: Int
|
||||
let useMmap: Bool
|
||||
let nPrompt: Int
|
||||
let nGenerate: Int
|
||||
let backend: Int
|
||||
let precision: Int
|
||||
let power: Int
|
||||
let memory: Int
|
||||
let dynamicOption: Int
|
||||
|
||||
@Published var prefillUs: [Int64] = []
|
||||
@Published var decodeUs: [Int64] = []
|
||||
@Published var samplesUs: [Int64] = []
|
||||
|
||||
init(modelConfigFile: String,
|
||||
modelType: String,
|
||||
modelSize: Int64 = 0,
|
||||
threads: Int,
|
||||
useMmap: Bool,
|
||||
nPrompt: Int,
|
||||
nGenerate: Int,
|
||||
backend: Int,
|
||||
precision: Int,
|
||||
power: Int,
|
||||
memory: Int,
|
||||
dynamicOption: Int) {
|
||||
self.modelConfigFile = modelConfigFile
|
||||
self.modelType = modelType
|
||||
self.modelSize = modelSize
|
||||
self.threads = threads
|
||||
self.useMmap = useMmap
|
||||
self.nPrompt = nPrompt
|
||||
self.nGenerate = nGenerate
|
||||
self.backend = backend
|
||||
self.precision = precision
|
||||
self.power = power
|
||||
self.memory = memory
|
||||
self.dynamicOption = dynamicOption
|
||||
}
|
||||
|
||||
/// Calculates tokens per second from timing data
|
||||
/// - Parameters:
|
||||
/// - tokens: Number of tokens processed
|
||||
/// - timesUs: Array of timing measurements in microseconds
|
||||
/// - Returns: Array of tokens per second calculations
|
||||
func getTokensPerSecond(tokens: Int, timesUs: [Int64]) -> [Double] {
|
||||
return timesUs.compactMap { timeUs in
|
||||
guard timeUs > 0 else { return 0.0 }
|
||||
return Double(tokens) * 1_000_000.0 / Double(timeUs)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
//
|
||||
// TestParameters.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
/**
|
||||
* Configuration parameters for benchmark test execution.
|
||||
* Defines test scenarios including prompt sizes, generation lengths, and repetition counts.
|
||||
*/
|
||||
struct TestParameters {
|
||||
let nPrompt: [Int]
|
||||
let nGenerate: [Int]
|
||||
let nPrompGen: [(Int, Int)]
|
||||
let nRepeat: [Int]
|
||||
let kvCache: String
|
||||
let loadTime: String
|
||||
|
||||
static let `default` = TestParameters(
|
||||
nPrompt: [256, 512],
|
||||
nGenerate: [64, 128],
|
||||
nPrompGen: [(256, 64), (512, 128)],
|
||||
nRepeat: [3], // Reduced for mobile
|
||||
kvCache: "false", // llama-bench style test by default
|
||||
loadTime: "false"
|
||||
)
|
||||
}
|
|
@ -0,0 +1,414 @@
|
|||
//
|
||||
// BenchmarkService.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import Combine
|
||||
|
||||
/**
|
||||
* Protocol defining callback methods for benchmark execution events.
|
||||
* Provides progress updates, completion notifications, and error handling.
|
||||
*/
|
||||
protocol BenchmarkCallback: AnyObject {
|
||||
func onProgress(_ progress: BenchmarkProgress)
|
||||
func onComplete(_ result: BenchmarkResult)
|
||||
func onBenchmarkError(_ errorCode: Int, _ message: String)
|
||||
}
|
||||
|
||||
/**
|
||||
* Singleton service class responsible for managing benchmark operations.
|
||||
* Coordinates with LLMInferenceEngineWrapper to execute performance tests
|
||||
* and provides real-time progress updates through callback mechanisms.
|
||||
*/
|
||||
class BenchmarkService: ObservableObject {
|
||||
|
||||
// MARK: - Singleton & Properties
|
||||
|
||||
static let shared = BenchmarkService()
|
||||
|
||||
@Published private(set) var isRunning = false
|
||||
private var shouldStop = false
|
||||
private var currentTask: Task<Void, Never>?
|
||||
|
||||
// Real LLM inference engine - using actual MNN LLM wrapper
|
||||
private var llmEngine: LLMInferenceEngineWrapper?
|
||||
private var currentModelId: String?
|
||||
|
||||
private init() {}
|
||||
|
||||
// MARK: - Public Interface
|
||||
|
||||
/// Initiates benchmark execution with specified parameters and callback handler
|
||||
/// - Parameters:
|
||||
/// - modelId: Identifier for the model to benchmark
|
||||
/// - callback: Callback handler for progress and completion events
|
||||
/// - runtimeParams: Runtime configuration parameters
|
||||
/// - testParams: Test scenario parameters
|
||||
func runBenchmark(
|
||||
modelId: String,
|
||||
callback: BenchmarkCallback,
|
||||
runtimeParams: RuntimeParameters = .default,
|
||||
testParams: TestParameters = .default
|
||||
) {
|
||||
guard !isRunning else {
|
||||
callback.onBenchmarkError(BenchmarkErrorCode.benchmarkRunning.rawValue, "Benchmark is already running")
|
||||
return
|
||||
}
|
||||
|
||||
guard let engine = llmEngine, engine.isModelReady() else {
|
||||
callback.onBenchmarkError(BenchmarkErrorCode.modelNotInitialized.rawValue, "Model is not initialized or not ready")
|
||||
return
|
||||
}
|
||||
|
||||
isRunning = true
|
||||
shouldStop = false
|
||||
|
||||
currentTask = Task {
|
||||
await performBenchmark(
|
||||
engine: engine,
|
||||
modelId: modelId,
|
||||
callback: callback,
|
||||
runtimeParams: runtimeParams,
|
||||
testParams: testParams
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops the currently running benchmark operation
|
||||
func stopBenchmark() {
|
||||
shouldStop = true
|
||||
llmEngine?.stopBenchmark()
|
||||
currentTask?.cancel()
|
||||
isRunning = false
|
||||
}
|
||||
|
||||
/// Checks if the model is properly initialized and ready for benchmarking
|
||||
/// - Returns: True if model is ready, false otherwise
|
||||
func isModelInitialized() -> Bool {
|
||||
return llmEngine != nil && llmEngine!.isModelReady()
|
||||
}
|
||||
|
||||
/// Initializes a model for benchmark testing
|
||||
/// - Parameters:
|
||||
/// - modelId: Identifier for the model
|
||||
/// - modelPath: File system path to the model
|
||||
/// - Returns: True if initialization succeeded, false otherwise
|
||||
func initializeModel(modelId: String, modelPath: String) async -> Bool {
|
||||
return await withCheckedContinuation { continuation in
|
||||
// Release existing engine if any
|
||||
llmEngine = nil
|
||||
currentModelId = nil
|
||||
|
||||
// Create new LLM inference engine
|
||||
llmEngine = LLMInferenceEngineWrapper(modelPath: modelPath) { success in
|
||||
if success {
|
||||
self.currentModelId = modelId
|
||||
print("BenchmarkService: Model \(modelId) initialized successfully")
|
||||
} else {
|
||||
self.llmEngine = nil
|
||||
print("BenchmarkService: Failed to initialize model \(modelId)")
|
||||
}
|
||||
continuation.resume(returning: success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves information about the currently loaded model
|
||||
/// - Returns: Model information string, or nil if no model is loaded
|
||||
func getModelInfo() -> String? {
|
||||
guard let modelId = currentModelId else { return nil }
|
||||
return "Model: \(modelId), Engine: MNN LLM"
|
||||
}
|
||||
|
||||
/// Releases the current model and frees associated resources
|
||||
func releaseModel() {
|
||||
llmEngine = nil
|
||||
currentModelId = nil
|
||||
}
|
||||
|
||||
// MARK: - Benchmark Execution
|
||||
|
||||
/// Performs the actual benchmark execution with progress tracking
|
||||
/// - Parameters:
|
||||
/// - engine: LLM inference engine instance
|
||||
/// - modelId: Model identifier
|
||||
/// - callback: Progress and completion callback handler
|
||||
/// - runtimeParams: Runtime configuration
|
||||
/// - testParams: Test parameters
|
||||
private func performBenchmark(
|
||||
engine: LLMInferenceEngineWrapper,
|
||||
modelId: String,
|
||||
callback: BenchmarkCallback,
|
||||
runtimeParams: RuntimeParameters,
|
||||
testParams: TestParameters
|
||||
) async {
|
||||
do {
|
||||
let instances = generateTestInstances(runtimeParams: runtimeParams, testParams: testParams)
|
||||
|
||||
var completedInstances = 0
|
||||
let totalInstances = instances.count
|
||||
|
||||
for instance in instances {
|
||||
if shouldStop {
|
||||
await MainActor.run {
|
||||
callback.onBenchmarkError(BenchmarkErrorCode.benchmarkStopped.rawValue, "Benchmark stopped by user")
|
||||
self.isRunning = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Create TestInstance for current configuration
|
||||
let testInstance = TestInstance(
|
||||
modelConfigFile: instance.configPath,
|
||||
modelType: modelId,
|
||||
modelSize: 0, // Will be calculated if needed
|
||||
threads: instance.threads,
|
||||
useMmap: instance.useMmap,
|
||||
nPrompt: instance.nPrompt,
|
||||
nGenerate: instance.nGenerate,
|
||||
backend: instance.backend,
|
||||
precision: instance.precision,
|
||||
power: instance.power,
|
||||
memory: instance.memory,
|
||||
dynamicOption: instance.dynamicOption
|
||||
)
|
||||
|
||||
// Update overall progress
|
||||
let progress = (completedInstances * 100) / totalInstances
|
||||
let statusMsg = "Running test \(completedInstances + 1)/\(totalInstances): pp\(instance.nPrompt)+tg\(instance.nGenerate)"
|
||||
|
||||
await MainActor.run {
|
||||
callback.onProgress(BenchmarkProgress(
|
||||
progress: progress,
|
||||
statusMessage: statusMsg,
|
||||
progressType: .runningTest,
|
||||
currentIteration: completedInstances + 1,
|
||||
totalIterations: totalInstances,
|
||||
nPrompt: instance.nPrompt,
|
||||
nGenerate: instance.nGenerate
|
||||
))
|
||||
}
|
||||
|
||||
// Execute benchmark using LLMInferenceEngineWrapper
|
||||
let result = await runOfficialBenchmark(
|
||||
engine: engine,
|
||||
instance: instance,
|
||||
testInstance: testInstance,
|
||||
progressCallback: { progress in
|
||||
await MainActor.run {
|
||||
callback.onProgress(progress)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if result.success {
|
||||
completedInstances += 1
|
||||
|
||||
// Only call onComplete for the last test instance
|
||||
if completedInstances == totalInstances {
|
||||
await MainActor.run {
|
||||
callback.onComplete(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await MainActor.run {
|
||||
callback.onBenchmarkError(BenchmarkErrorCode.testInstanceFailed.rawValue, result.errorMessage ?? "Test failed")
|
||||
self.isRunning = false
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
self.isRunning = false
|
||||
}
|
||||
|
||||
} catch {
|
||||
await MainActor.run {
|
||||
callback.onBenchmarkError(BenchmarkErrorCode.nativeError.rawValue, error.localizedDescription)
|
||||
self.isRunning = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes a single benchmark test using the official MNN LLM benchmark interface
|
||||
/// - Parameters:
|
||||
/// - engine: LLM inference engine
|
||||
/// - instance: Test configuration
|
||||
/// - testInstance: Test instance to populate with results
|
||||
/// - progressCallback: Callback for progress updates
|
||||
/// - Returns: Benchmark result with success status and timing data
|
||||
private func runOfficialBenchmark(
|
||||
engine: LLMInferenceEngineWrapper,
|
||||
instance: TestConfig,
|
||||
testInstance: TestInstance,
|
||||
progressCallback: @escaping (BenchmarkProgress) async -> Void
|
||||
) async -> BenchmarkResult {
|
||||
|
||||
return await withCheckedContinuation { continuation in
|
||||
var hasResumed = false
|
||||
|
||||
engine.runOfficialBenchmark(
|
||||
withBackend: instance.backend,
|
||||
threads: instance.threads,
|
||||
useMmap: instance.useMmap,
|
||||
power: instance.power,
|
||||
precision: instance.precision,
|
||||
memory: instance.memory,
|
||||
dynamicOption: instance.dynamicOption,
|
||||
nPrompt: instance.nPrompt,
|
||||
nGenerate: instance.nGenerate,
|
||||
nRepeat: instance.nRepeat,
|
||||
kvCache: instance.kvCache,
|
||||
progressCallback: { [self] progressInfo in
|
||||
// Convert Objective-C BenchmarkProgressInfo to Swift BenchmarkProgress
|
||||
let swiftProgress = BenchmarkProgress(
|
||||
progress: Int(progressInfo.progress),
|
||||
statusMessage: progressInfo.statusMessage,
|
||||
progressType: convertProgressType(progressInfo.progressType),
|
||||
currentIteration: Int(progressInfo.currentIteration),
|
||||
totalIterations: Int(progressInfo.totalIterations),
|
||||
nPrompt: Int(progressInfo.nPrompt),
|
||||
nGenerate: Int(progressInfo.nGenerate),
|
||||
runTimeSeconds: progressInfo.runTimeSeconds,
|
||||
prefillTimeSeconds: progressInfo.prefillTimeSeconds,
|
||||
decodeTimeSeconds: progressInfo.decodeTimeSeconds,
|
||||
prefillSpeed: progressInfo.prefillSpeed,
|
||||
decodeSpeed: progressInfo.decodeSpeed
|
||||
)
|
||||
|
||||
Task {
|
||||
await progressCallback(swiftProgress)
|
||||
}
|
||||
},
|
||||
errorCallback: { errorMessage in
|
||||
if !hasResumed {
|
||||
hasResumed = true
|
||||
let result = BenchmarkResult(
|
||||
testInstance: testInstance,
|
||||
success: false,
|
||||
errorMessage: errorMessage
|
||||
)
|
||||
continuation.resume(returning: result)
|
||||
}
|
||||
},
|
||||
iterationCompleteCallback: { detailedStats in
|
||||
// Log detailed stats if needed
|
||||
print("Benchmark iteration complete: \(detailedStats)")
|
||||
},
|
||||
completeCallback: { benchmarkResult in
|
||||
if !hasResumed {
|
||||
hasResumed = true
|
||||
|
||||
// Update test instance with timing results
|
||||
testInstance.prefillUs = benchmarkResult.prefillTimesUs.map { $0.int64Value }
|
||||
testInstance.decodeUs = benchmarkResult.decodeTimesUs.map { $0.int64Value }
|
||||
testInstance.samplesUs = benchmarkResult.sampleTimesUs.map { $0.int64Value }
|
||||
|
||||
let result = BenchmarkResult(
|
||||
testInstance: testInstance,
|
||||
success: benchmarkResult.success,
|
||||
errorMessage: benchmarkResult.errorMessage
|
||||
)
|
||||
continuation.resume(returning: result)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helper Methods & Configuration
|
||||
|
||||
/// Converts Objective-C progress type to Swift enum
|
||||
/// - Parameter objcType: Objective-C progress type
|
||||
/// - Returns: Corresponding Swift ProgressType
|
||||
private func convertProgressType(_ objcType: BenchmarkProgressType) -> ProgressType {
|
||||
switch objcType {
|
||||
case .unknown:
|
||||
return .unknown
|
||||
case .initializing:
|
||||
return .initializing
|
||||
case .warmingUp:
|
||||
return .warmingUp
|
||||
case .runningTest:
|
||||
return .runningTest
|
||||
case .processingResults:
|
||||
return .processingResults
|
||||
case .completed:
|
||||
return .completed
|
||||
case .stopping:
|
||||
return .stopping
|
||||
@unknown default:
|
||||
return .unknown
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates test instances by combining runtime and test parameters
|
||||
/// - Parameters:
|
||||
/// - runtimeParams: Runtime configuration parameters
|
||||
/// - testParams: Test scenario parameters
|
||||
/// - Returns: Array of test configurations for execution
|
||||
private func generateTestInstances(
|
||||
runtimeParams: RuntimeParameters,
|
||||
testParams: TestParameters
|
||||
) -> [TestConfig] {
|
||||
var instances: [TestConfig] = []
|
||||
|
||||
for backend in runtimeParams.backends {
|
||||
for thread in runtimeParams.threads {
|
||||
for power in runtimeParams.power {
|
||||
for precision in runtimeParams.precision {
|
||||
for memory in runtimeParams.memory {
|
||||
for dynamicOption in runtimeParams.dynamicOption {
|
||||
for repeatCount in testParams.nRepeat {
|
||||
for (nPrompt, nGenerate) in testParams.nPrompGen {
|
||||
instances.append(TestConfig(
|
||||
configPath: "", // Will be set based on model
|
||||
backend: backend,
|
||||
threads: thread,
|
||||
useMmap: runtimeParams.useMmap,
|
||||
power: power,
|
||||
precision: precision,
|
||||
memory: memory,
|
||||
dynamicOption: dynamicOption,
|
||||
nPrompt: nPrompt,
|
||||
nGenerate: nGenerate,
|
||||
nRepeat: repeatCount,
|
||||
kvCache: testParams.kvCache == "true"
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instances
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Test Configuration
|
||||
|
||||
/**
|
||||
* Structure containing configuration parameters for a single benchmark test.
|
||||
* Combines runtime settings and test parameters into a complete test specification.
|
||||
*/
|
||||
struct TestConfig {
|
||||
let configPath: String
|
||||
let backend: Int
|
||||
let threads: Int
|
||||
let useMmap: Bool
|
||||
let power: Int
|
||||
let precision: Int
|
||||
let memory: Int
|
||||
let dynamicOption: Int
|
||||
let nPrompt: Int
|
||||
let nGenerate: Int
|
||||
let nRepeat: Int
|
||||
let kvCache: Bool
|
||||
}
|
|
@ -0,0 +1,444 @@
|
|||
//
|
||||
// BenchmarkViewModel.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
import Combine
|
||||
|
||||
/**
|
||||
* ViewModel for managing benchmark operations including model selection, test execution,
|
||||
* progress tracking, and result management. Handles communication with BenchmarkService
|
||||
* and provides UI state management for the benchmark interface.
|
||||
*/
|
||||
@MainActor
|
||||
class BenchmarkViewModel: ObservableObject {
|
||||
|
||||
// MARK: - Published Properties
|
||||
|
||||
@Published var isLoading = false
|
||||
@Published var isRunning = false
|
||||
@Published var showProgressBar = false
|
||||
@Published var showResults = false
|
||||
@Published var showError = false
|
||||
|
||||
@Published var selectedModel: ModelInfo?
|
||||
@Published var availableModels: [ModelInfo] = []
|
||||
@Published var currentProgress: BenchmarkProgress?
|
||||
@Published var benchmarkResults: BenchmarkResults?
|
||||
@Published var errorMessage: String = ""
|
||||
@Published var statusMessage: String = ""
|
||||
|
||||
@Published var startButtonText = String(localized: "Start Test")
|
||||
@Published var isStartButtonEnabled = true
|
||||
|
||||
// MARK: - Private Properties
|
||||
|
||||
private let benchmarkService = BenchmarkService.shared
|
||||
private let resultsHelper = BenchmarkResultsHelper.shared
|
||||
private var cancellables = Set<AnyCancellable>()
|
||||
|
||||
// Model list manager for getting local models
|
||||
private let modelListManager = ModelListManager.shared
|
||||
|
||||
// MARK: - Initialization & Setup
|
||||
|
||||
init() {
|
||||
setupBindings()
|
||||
loadAvailableModels()
|
||||
}
|
||||
|
||||
/// Sets up reactive bindings between service and view model
|
||||
private func setupBindings() {
|
||||
benchmarkService.$isRunning
|
||||
.receive(on: DispatchQueue.main)
|
||||
.assign(to: \.isRunning, on: self)
|
||||
.store(in: &cancellables)
|
||||
|
||||
// Update button text based on running state
|
||||
benchmarkService.$isRunning
|
||||
.receive(on: DispatchQueue.main)
|
||||
.map { isRunning in
|
||||
isRunning ? String(localized: "Stop Test") : String(localized: "Start Test")
|
||||
}
|
||||
.assign(to: \.startButtonText, on: self)
|
||||
.store(in: &cancellables)
|
||||
}
|
||||
|
||||
/// Loads available models from ModelListManager, filtering for downloaded models only
|
||||
private func loadAvailableModels() {
|
||||
Task {
|
||||
isLoading = true
|
||||
|
||||
do {
|
||||
// Get all models from ModelListManager
|
||||
let allModels = try await modelListManager.loadModels()
|
||||
|
||||
// Filter only downloaded models that are available locally
|
||||
availableModels = allModels.filter { model in
|
||||
model.isDownloaded && model.localPath != nil
|
||||
}
|
||||
|
||||
print("BenchmarkViewModel: Loaded \(availableModels.count) available local models")
|
||||
|
||||
} catch {
|
||||
showErrorMessage("Failed to load models: \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
isLoading = false
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Public Action Handlers
|
||||
|
||||
/// Handles start/stop benchmark button taps
|
||||
func onStartBenchmarkTapped() {
|
||||
if !isRunning {
|
||||
startBenchmark()
|
||||
} else {
|
||||
showStopConfirmationAlert()
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles benchmark stop confirmation
|
||||
func onStopBenchmarkTapped() {
|
||||
stopBenchmark()
|
||||
}
|
||||
|
||||
/// Handles model selection from dropdown
|
||||
func onModelSelected(_ model: ModelInfo) {
|
||||
selectedModel = model
|
||||
}
|
||||
|
||||
/// Handles result deletion and cleanup
|
||||
func onDeleteResultTapped() {
|
||||
benchmarkResults = nil
|
||||
showResults = false
|
||||
hideStatus()
|
||||
|
||||
// Release model to free memory
|
||||
benchmarkService.releaseModel()
|
||||
}
|
||||
|
||||
/// Placeholder for future result submission functionality
|
||||
func onSubmitResultTapped() {
|
||||
// Implementation for submitting results (if needed)
|
||||
// This could involve sharing or uploading results
|
||||
}
|
||||
|
||||
// MARK: - Benchmark Execution
|
||||
|
||||
/// Initiates benchmark test with selected model and configured parameters
|
||||
private func startBenchmark() {
|
||||
guard let model = selectedModel else {
|
||||
showErrorMessage("Please select a model first")
|
||||
return
|
||||
}
|
||||
|
||||
guard model.isDownloaded else {
|
||||
showErrorMessage("Selected model is not downloaded or path is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
onBenchmarkStarted()
|
||||
|
||||
Task {
|
||||
// Initialize model if needed
|
||||
let initialized = await benchmarkService.initializeModel(
|
||||
modelId: model.id,
|
||||
modelPath: model.localPath
|
||||
)
|
||||
|
||||
guard initialized else {
|
||||
showErrorMessage("Failed to initialize model")
|
||||
resetUIState()
|
||||
return
|
||||
}
|
||||
|
||||
// Start memory monitoring
|
||||
MemoryMonitor.shared.start()
|
||||
|
||||
// Start benchmark with optimized parameters for mobile devices
|
||||
benchmarkService.runBenchmark(
|
||||
modelId: model.id,
|
||||
callback: self,
|
||||
runtimeParams: createRuntimeParameters(),
|
||||
testParams: createTestParameters()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates runtime parameters optimized for iOS devices
|
||||
private func createRuntimeParameters() -> RuntimeParameters {
|
||||
return RuntimeParameters(
|
||||
backends: [0], // CPU backend
|
||||
threads: [4], // 4 threads for most iOS devices
|
||||
useMmap: false, // Memory mapping disabled for iOS
|
||||
power: [0], // Normal power mode
|
||||
precision: [2], // Low precision for better performance
|
||||
memory: [2], // Low memory usage
|
||||
dynamicOption: [0] // No dynamic optimization
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates test parameters suitable for mobile benchmarking
|
||||
private func createTestParameters() -> TestParameters {
|
||||
return TestParameters(
|
||||
nPrompt: [256, 512], // Smaller prompt sizes for mobile
|
||||
nGenerate: [64, 128], // Smaller generation sizes
|
||||
nPrompGen: [(256, 64), (512, 128)], // Combined test cases
|
||||
nRepeat: [3], // Fewer repetitions for faster testing
|
||||
kvCache: "false", // Disable KV cache by default
|
||||
loadTime: "false"
|
||||
)
|
||||
}
|
||||
|
||||
/// Stops the currently running benchmark
|
||||
private func stopBenchmark() {
|
||||
updateStatus("Stopping benchmark...")
|
||||
benchmarkService.stopBenchmark()
|
||||
MemoryMonitor.shared.stop()
|
||||
}
|
||||
|
||||
// MARK: - UI State Management
|
||||
|
||||
/// Updates UI state when benchmark starts
|
||||
private func onBenchmarkStarted() {
|
||||
isStartButtonEnabled = true
|
||||
showProgressBar = true
|
||||
showResults = false
|
||||
updateStatus("Initializing benchmark...")
|
||||
}
|
||||
|
||||
/// Resets UI to initial state
|
||||
private func resetUIState() {
|
||||
isStartButtonEnabled = true
|
||||
showProgressBar = false
|
||||
hideStatus()
|
||||
showResults = false
|
||||
MemoryMonitor.shared.stop()
|
||||
}
|
||||
|
||||
/// Updates status message display
|
||||
private func updateStatus(_ message: String) {
|
||||
statusMessage = message
|
||||
}
|
||||
|
||||
/// Hides status message
|
||||
private func hideStatus() {
|
||||
statusMessage = ""
|
||||
}
|
||||
|
||||
/// Shows error message alert
|
||||
private func showErrorMessage(_ message: String) {
|
||||
errorMessage = message
|
||||
showError = true
|
||||
}
|
||||
|
||||
/// Placeholder for stop confirmation alert (handled in View)
|
||||
private func showStopConfirmationAlert() {
|
||||
// This will be handled in the View with an alert
|
||||
}
|
||||
|
||||
/// Formats progress messages with appropriate status text based on progress type
|
||||
private func formatProgressMessage(_ progress: BenchmarkProgress) -> BenchmarkProgress {
|
||||
let formattedMessage: String
|
||||
|
||||
switch progress.progressType {
|
||||
case .initializing:
|
||||
formattedMessage = "Initializing benchmark..."
|
||||
case .warmingUp:
|
||||
formattedMessage = "Warming up..."
|
||||
case .runningTest:
|
||||
formattedMessage = "Running test \(progress.currentIteration)/\(progress.totalIterations)"
|
||||
case .processingResults:
|
||||
formattedMessage = "Processing results..."
|
||||
case .completed:
|
||||
formattedMessage = "All tests completed"
|
||||
case .stopping:
|
||||
formattedMessage = "Stopping benchmark..."
|
||||
default:
|
||||
formattedMessage = progress.statusMessage
|
||||
}
|
||||
|
||||
return BenchmarkProgress(
|
||||
progress: progress.progress,
|
||||
statusMessage: formattedMessage,
|
||||
progressType: progress.progressType,
|
||||
currentIteration: progress.currentIteration,
|
||||
totalIterations: progress.totalIterations,
|
||||
nPrompt: progress.nPrompt,
|
||||
nGenerate: progress.nGenerate,
|
||||
runTimeSeconds: progress.runTimeSeconds,
|
||||
prefillTimeSeconds: progress.prefillTimeSeconds,
|
||||
decodeTimeSeconds: progress.decodeTimeSeconds,
|
||||
prefillSpeed: progress.prefillSpeed,
|
||||
decodeSpeed: progress.decodeSpeed
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - BenchmarkCallback Implementation
|
||||
|
||||
extension BenchmarkViewModel: BenchmarkCallback {
|
||||
|
||||
/// Handles progress updates from benchmark service
|
||||
func onProgress(_ progress: BenchmarkProgress) {
|
||||
let formattedProgress = formatProgressMessage(progress)
|
||||
currentProgress = formattedProgress
|
||||
updateStatus(formattedProgress.statusMessage)
|
||||
}
|
||||
|
||||
/// Handles benchmark completion with results processing
|
||||
func onComplete(_ result: BenchmarkResult) {
|
||||
guard let model = selectedModel else { return }
|
||||
|
||||
updateStatus("Processing results...")
|
||||
|
||||
// Create comprehensive benchmark results
|
||||
let results = BenchmarkResults(
|
||||
modelDisplayName: model.modelName,
|
||||
maxMemoryKb: MemoryMonitor.shared.getMaxMemoryKb(),
|
||||
testResults: [result.testInstance],
|
||||
timestamp: DateFormatter.benchmarkTimestamp.string(from: Date())
|
||||
)
|
||||
|
||||
benchmarkResults = results
|
||||
showResults = true
|
||||
|
||||
// Only stop memory monitoring if benchmark is no longer running (all tests completed)
|
||||
if !isRunning {
|
||||
// Stop memory monitoring
|
||||
MemoryMonitor.shared.stop()
|
||||
}
|
||||
|
||||
// Always hide status after processing results
|
||||
hideStatus()
|
||||
|
||||
print("BenchmarkViewModel: Benchmark completed successfully for model: \(model.modelName)")
|
||||
}
|
||||
|
||||
/// Handles benchmark errors with user-friendly error messages
|
||||
func onBenchmarkError(_ errorCode: Int, _ message: String) {
|
||||
let errorCodeName = BenchmarkErrorCode(rawValue: errorCode)?.description ?? "Unknown"
|
||||
showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
|
||||
resetUIState()
|
||||
print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Memory Monitoring
|
||||
|
||||
/**
|
||||
* Singleton class for monitoring memory usage during benchmark execution.
|
||||
* Tracks current and peak memory consumption using system APIs.
|
||||
*/
|
||||
class MemoryMonitor: ObservableObject {
|
||||
|
||||
static let shared = MemoryMonitor()
|
||||
|
||||
@Published private(set) var currentMemoryKb: Int64 = 0
|
||||
private var maxMemoryKb: Int64 = 0
|
||||
private var isMonitoring = false
|
||||
private var monitoringTask: Task<Void, Never>?
|
||||
|
||||
private init() {}
|
||||
|
||||
/// Starts continuous memory monitoring
|
||||
func start() {
|
||||
guard !isMonitoring else { return }
|
||||
|
||||
isMonitoring = true
|
||||
maxMemoryKb = 0
|
||||
|
||||
monitoringTask = Task {
|
||||
while isMonitoring && !Task.isCancelled {
|
||||
await updateMemoryUsage()
|
||||
try? await Task.sleep(nanoseconds: 500_000_000) // 0.5 seconds
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops memory monitoring
|
||||
func stop() {
|
||||
isMonitoring = false
|
||||
monitoringTask?.cancel()
|
||||
monitoringTask = nil
|
||||
}
|
||||
|
||||
/// Resets memory tracking counters
|
||||
func reset() {
|
||||
maxMemoryKb = 0
|
||||
currentMemoryKb = 0
|
||||
}
|
||||
|
||||
/// Returns the maximum memory usage recorded during monitoring
|
||||
func getMaxMemoryKb() -> Int64 {
|
||||
return maxMemoryKb
|
||||
}
|
||||
|
||||
/// Updates current memory usage and tracks maximum
|
||||
@MainActor
|
||||
private func updateMemoryUsage() {
|
||||
let memoryUsage = getCurrentMemoryUsage()
|
||||
currentMemoryKb = memoryUsage
|
||||
maxMemoryKb = max(maxMemoryKb, memoryUsage)
|
||||
}
|
||||
|
||||
/// Gets current memory usage from system using mach task info
|
||||
private func getCurrentMemoryUsage() -> Int64 {
|
||||
var info = mach_task_basic_info()
|
||||
var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4
|
||||
|
||||
let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) {
|
||||
$0.withMemoryRebound(to: integer_t.self, capacity: 1) {
|
||||
task_info(mach_task_self_,
|
||||
task_flavor_t(MACH_TASK_BASIC_INFO),
|
||||
$0,
|
||||
&count)
|
||||
}
|
||||
}
|
||||
|
||||
if kerr == KERN_SUCCESS {
|
||||
return Int64(info.resident_size) / 1024 // Convert to KB
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
/// Extension providing user-friendly descriptions for benchmark error codes
|
||||
extension BenchmarkErrorCode {
|
||||
var description: String {
|
||||
switch self {
|
||||
case .benchmarkFailedUnknown:
|
||||
return "Unknown Error"
|
||||
case .testInstanceFailed:
|
||||
return "Test Failed"
|
||||
case .modelNotInitialized:
|
||||
return "Model Not Ready"
|
||||
case .benchmarkRunning:
|
||||
return "Already Running"
|
||||
case .benchmarkStopped:
|
||||
return "Stopped"
|
||||
case .nativeError:
|
||||
return "Native Error"
|
||||
case .modelError:
|
||||
return "Model Error"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension providing formatted timestamp for benchmark results
|
||||
extension DateFormatter {
|
||||
static let benchmarkTimestamp: DateFormatter = {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyy/M/dd HH:mm:ss"
|
||||
return formatter
|
||||
}()
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
//
|
||||
// MetricCard.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Reusable metric display card component.
|
||||
* Shows performance metrics with icon, title, and value in a compact format.
|
||||
*/
|
||||
struct MetricCard: View {
|
||||
let title: String
|
||||
let value: String
|
||||
let icon: String
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack(spacing: 6) {
|
||||
Image(systemName: icon)
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkAccent)
|
||||
|
||||
Text(title)
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
.lineLimit(1)
|
||||
}
|
||||
|
||||
Text(value)
|
||||
.font(.system(size: 14, weight: .semibold))
|
||||
.foregroundColor(.primary)
|
||||
.lineLimit(1)
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(Color.benchmarkAccent.opacity(0.05))
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.benchmarkAccent.opacity(0.1), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
HStack(spacing: 12) {
|
||||
MetricCard(title: "Runtime", value: "2.456s", icon: "clock")
|
||||
MetricCard(title: "Speed", value: "109.8 t/s", icon: "speedometer")
|
||||
MetricCard(title: "Memory", value: "1.2 GB", icon: "memorychip")
|
||||
}
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,241 @@
|
|||
//
|
||||
// ModelSelectionCard.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Reusable model selection card component for benchmark interface.
|
||||
* Provides dropdown menu for model selection and start/stop controls.
|
||||
*/
|
||||
struct ModelSelectionCard: View {
|
||||
@ObservedObject var viewModel: BenchmarkViewModel
|
||||
@Binding var showStopConfirmation: Bool
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 16) {
|
||||
HStack {
|
||||
Text("Select Model")
|
||||
.font(.title3)
|
||||
.fontWeight(.semibold)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Spacer()
|
||||
}
|
||||
|
||||
if viewModel.isLoading {
|
||||
HStack {
|
||||
ProgressView()
|
||||
.scaleEffect(0.8)
|
||||
Text("Loading models...")
|
||||
.font(.subheadline)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
} else {
|
||||
modelDropdownMenu
|
||||
}
|
||||
|
||||
startStopButton
|
||||
|
||||
statusMessages
|
||||
}
|
||||
.padding(20)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(Color.benchmarkCardBg)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Private Views
|
||||
|
||||
private var modelDropdownMenu: some View {
|
||||
Menu {
|
||||
if viewModel.availableModels.isEmpty {
|
||||
Button("No models available") {
|
||||
// Placeholder - no action
|
||||
}
|
||||
.disabled(true)
|
||||
} else {
|
||||
ForEach(viewModel.availableModels, id: \.id) { model in
|
||||
Button(action: {
|
||||
viewModel.onModelSelected(model)
|
||||
}) {
|
||||
HStack {
|
||||
VStack(alignment: .leading, spacing: 2) {
|
||||
Text(model.modelName)
|
||||
.font(.system(size: 14, weight: .medium))
|
||||
Text("Local")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} label: {
|
||||
HStack(spacing: 16) {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
Text(viewModel.selectedModel?.modelName ?? String(localized: "Choose your AI model"))
|
||||
.font(.system(size: 16, weight: .medium))
|
||||
.foregroundColor(viewModel.isRunning ? .secondary : (viewModel.selectedModel != nil ? .primary : .benchmarkSecondary))
|
||||
.lineLimit(1)
|
||||
|
||||
if let model = viewModel.selectedModel {
|
||||
HStack(spacing: 8) {
|
||||
HStack(spacing: 4) {
|
||||
Circle()
|
||||
.fill(Color.benchmarkSuccess)
|
||||
.frame(width: 6, height: 6)
|
||||
Text("Ready")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSuccess)
|
||||
}
|
||||
|
||||
if let size = model.cachedSize {
|
||||
Text("• \(formatBytes(size))")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Text("Tap to select a model for testing")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Image(systemName: "chevron.down")
|
||||
.font(.system(size: 14, weight: .medium))
|
||||
.foregroundColor(viewModel.isRunning ? .secondary : .benchmarkSecondary)
|
||||
.rotationEffect(.degrees(0))
|
||||
}
|
||||
.padding(20)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(Color.benchmarkCardBg)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.stroke(
|
||||
viewModel.isRunning ?
|
||||
Color.gray.opacity(0.1) :
|
||||
(viewModel.selectedModel != nil ?
|
||||
Color.benchmarkAccent.opacity(0.3) :
|
||||
Color.gray.opacity(0.2)),
|
||||
lineWidth: 1
|
||||
)
|
||||
))
|
||||
}
|
||||
.disabled(viewModel.isRunning)
|
||||
}
|
||||
|
||||
private var startStopButton: some View {
|
||||
Button(action: {
|
||||
if viewModel.startButtonText.contains("Stop") {
|
||||
showStopConfirmation = true
|
||||
} else {
|
||||
viewModel.onStartBenchmarkTapped()
|
||||
}
|
||||
}) {
|
||||
HStack(spacing: 12) {
|
||||
ZStack {
|
||||
Circle()
|
||||
.fill(Color.white.opacity(0.2))
|
||||
.frame(width: 32, height: 32)
|
||||
|
||||
if viewModel.isRunning && viewModel.startButtonText.contains("Stop") {
|
||||
ProgressView()
|
||||
.progressViewStyle(CircularProgressViewStyle(tint: .white))
|
||||
.scaleEffect(0.7)
|
||||
} else {
|
||||
Image(systemName: viewModel.startButtonText.contains("Stop") ? "stop.fill" : "play.fill")
|
||||
.font(.system(size: 16, weight: .bold))
|
||||
.foregroundColor(.white)
|
||||
}
|
||||
}
|
||||
|
||||
Text(viewModel.startButtonText)
|
||||
.font(.system(size: 18, weight: .semibold))
|
||||
.foregroundColor(.white)
|
||||
|
||||
Spacer()
|
||||
|
||||
if !viewModel.startButtonText.contains("Stop") {
|
||||
Image(systemName: "arrow.right")
|
||||
.font(.system(size: 16, weight: .semibold))
|
||||
.foregroundColor(.white.opacity(0.8))
|
||||
}
|
||||
}
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.horizontal, 24)
|
||||
.padding(.vertical, 18)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(
|
||||
viewModel.isStartButtonEnabled ?
|
||||
(viewModel.startButtonText.contains("Stop") ?
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkError, Color.benchmarkError.opacity(0.8)],
|
||||
startPoint: .leading,
|
||||
endPoint: .trailing
|
||||
) :
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkGradientStart, Color.benchmarkGradientEnd],
|
||||
startPoint: .leading,
|
||||
endPoint: .trailing
|
||||
)) :
|
||||
LinearGradient(
|
||||
colors: [Color.gray, Color.gray.opacity(0.8)],
|
||||
startPoint: .leading,
|
||||
endPoint: .trailing
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
.disabled(!viewModel.isStartButtonEnabled || viewModel.selectedModel == nil)
|
||||
.animation(.easeInOut(duration: 0.2), value: viewModel.startButtonText)
|
||||
.animation(.easeInOut(duration: 0.2), value: viewModel.isStartButtonEnabled)
|
||||
}
|
||||
|
||||
private var statusMessages: some View {
|
||||
Group {
|
||||
if viewModel.selectedModel == nil {
|
||||
Text("Start benchmark after selecting your model")
|
||||
.font(.caption)
|
||||
.foregroundColor(.orange)
|
||||
.padding(.horizontal, 16)
|
||||
} else if viewModel.availableModels.isEmpty {
|
||||
Text("No local models found. Please download a model first.")
|
||||
.font(.caption)
|
||||
.foregroundColor(.orange)
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helper Functions
|
||||
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useGB, .useMB]
|
||||
formatter.countStyle = .file
|
||||
return formatter.string(fromByteCount: bytes)
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
ModelSelectionCard(
|
||||
viewModel: BenchmarkViewModel(),
|
||||
showStopConfirmation: .constant(false)
|
||||
)
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
//
|
||||
// EnhancedPerformanceMetricView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Enhanced performance metric display component.
|
||||
* Shows detailed performance metrics with gradient backgrounds, icons, and custom colors.
|
||||
*/
|
||||
struct PerformanceMetricView: View {
|
||||
let icon: String
|
||||
let title: String
|
||||
let value: String
|
||||
let subtitle: String
|
||||
let color: Color
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .center, spacing: 12) {
|
||||
ZStack {
|
||||
Circle()
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [color.opacity(0.2), color.opacity(0.1)],
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.frame(width: 50, height: 50)
|
||||
|
||||
Image(systemName: icon)
|
||||
.font(.system(size: 25, weight: .semibold))
|
||||
.foregroundColor(color)
|
||||
}
|
||||
|
||||
VStack(alignment: .center, spacing: 2) {
|
||||
Text(title)
|
||||
.font(.subheadline)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text(subtitle)
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
|
||||
Text(value)
|
||||
.font(.title2)
|
||||
.fontWeight(.bold)
|
||||
.foregroundColor(color)
|
||||
.multilineTextAlignment(.center)
|
||||
.lineLimit(nil)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .center)
|
||||
.padding(16)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 12)
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkCardBg, color.opacity(0.02)],
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 12)
|
||||
.stroke(color.opacity(0.2), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#Preview {
|
||||
VStack(spacing: 16) {
|
||||
HStack(spacing: 12) {
|
||||
PerformanceMetricView(
|
||||
icon: "speedometer",
|
||||
title: "Prefill Speed",
|
||||
value: "1024.5 t/s",
|
||||
subtitle: "Tokens per second",
|
||||
color: .benchmarkGradientStart
|
||||
)
|
||||
|
||||
PerformanceMetricView(
|
||||
icon: "gauge",
|
||||
title: "Decode Speed",
|
||||
value: "109.8 t/s",
|
||||
subtitle: "Generation rate",
|
||||
color: .benchmarkGradientEnd
|
||||
)
|
||||
}
|
||||
|
||||
HStack(spacing: 12) {
|
||||
PerformanceMetricView(
|
||||
icon: "memorychip",
|
||||
title: "Memory Usage",
|
||||
value: "1.2 GB",
|
||||
subtitle: "Peak memory",
|
||||
color: .benchmarkWarning
|
||||
)
|
||||
|
||||
PerformanceMetricView(
|
||||
icon: "clock",
|
||||
title: "Total Time",
|
||||
value: "2.456s",
|
||||
subtitle: "Complete duration",
|
||||
color: .benchmarkSuccess
|
||||
)
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
//
|
||||
// ProgressCard.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Reusable progress tracking card component for benchmark interface.
|
||||
* Displays test progress with detailed metrics and visual indicators.
|
||||
*/
|
||||
struct ProgressCard: View {
|
||||
let progress: BenchmarkProgress?
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 20) {
|
||||
if let progress = progress {
|
||||
VStack(alignment: .leading, spacing: 16) {
|
||||
progressHeader(progress)
|
||||
progressBar(progress)
|
||||
|
||||
if progress.progressType == .runningTest && progress.totalIterations > 0 {
|
||||
testDetails(progress)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fallbackProgress
|
||||
}
|
||||
}
|
||||
.padding(20)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(Color.benchmarkCardBg)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Private Views
|
||||
|
||||
private func progressHeader(_ progress: BenchmarkProgress) -> some View {
|
||||
HStack {
|
||||
HStack(spacing: 12) {
|
||||
ZStack {
|
||||
Circle()
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkAccent.opacity(0.2), Color.benchmarkGradientEnd.opacity(0.1)],
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
Image(systemName: "chart.line.uptrend.xyaxis")
|
||||
.font(.system(size: 18, weight: .semibold))
|
||||
.foregroundColor(.benchmarkAccent)
|
||||
}
|
||||
|
||||
VStack(alignment: .leading, spacing: 2) {
|
||||
Text("Test Progress")
|
||||
.font(.title3)
|
||||
.fontWeight(.semibold)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text("Running performance tests")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
VStack(alignment: .trailing, spacing: 2) {
|
||||
Text("\(progress.progress)%")
|
||||
.font(.title2)
|
||||
.fontWeight(.bold)
|
||||
.foregroundColor(.benchmarkAccent)
|
||||
|
||||
Text("Complete")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func progressBar(_ progress: BenchmarkProgress) -> some View {
|
||||
VStack(spacing: 8) {
|
||||
ZStack(alignment: .leading) {
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(Color.gray.opacity(0.2))
|
||||
.frame(height: 8)
|
||||
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkGradientStart, Color.benchmarkGradientEnd],
|
||||
startPoint: .leading,
|
||||
endPoint: .trailing
|
||||
)
|
||||
)
|
||||
.frame(width: CGFloat(progress.progress) / 100 * UIScreen.main.bounds.width * 0.8, height: 8)
|
||||
.animation(.easeInOut(duration: 0.3), value: progress.progress)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func testDetails(_ progress: BenchmarkProgress) -> some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
// Test iteration info
|
||||
HStack {
|
||||
Image(systemName: "repeat")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkAccent)
|
||||
|
||||
Text("Test \(progress.currentIteration) of \(progress.totalIterations)")
|
||||
.font(.subheadline)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Spacer()
|
||||
|
||||
Text("PP: \(progress.nPrompt) • TG: \(progress.nGenerate)")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 4)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.fill(Color.benchmarkAccent.opacity(0.1))
|
||||
)
|
||||
}
|
||||
|
||||
// Real-time performance metrics
|
||||
if progress.runTimeSeconds > 0 {
|
||||
VStack(spacing: 12) {
|
||||
// Timing metrics
|
||||
HStack(spacing: 12) {
|
||||
MetricCard(title: "Runtime", value: String(format: "%.3fs", progress.runTimeSeconds), icon: "clock")
|
||||
MetricCard(title: "Prefill", value: String(format: "%.3fs", progress.prefillTimeSeconds), icon: "arrow.up.circle")
|
||||
MetricCard(title: "Decode", value: String(format: "%.3fs", progress.decodeTimeSeconds), icon: "arrow.down.circle")
|
||||
}
|
||||
|
||||
// Speed metrics
|
||||
HStack(spacing: 12) {
|
||||
MetricCard(title: "Prefill Speed", value: String(format: "%.2f t/s", progress.prefillSpeed), icon: "speedometer")
|
||||
MetricCard(title: "Decode Speed", value: String(format: "%.2f t/s", progress.decodeSpeed), icon: "gauge")
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var fallbackProgress: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
Text("Progress")
|
||||
.font(.headline)
|
||||
ProgressView()
|
||||
.progressViewStyle(LinearProgressViewStyle())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
ProgressCard(
|
||||
progress: BenchmarkProgress(
|
||||
progress: 65,
|
||||
statusMessage: "Running benchmark...",
|
||||
progressType: .runningTest,
|
||||
currentIteration: 3,
|
||||
totalIterations: 5,
|
||||
nPrompt: 128,
|
||||
nGenerate: 256,
|
||||
runTimeSeconds: 2.456,
|
||||
prefillTimeSeconds: 0.123,
|
||||
decodeTimeSeconds: 2.333,
|
||||
prefillSpeed: 1024.5,
|
||||
decodeSpeed: 109.8
|
||||
)
|
||||
)
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,311 @@
|
|||
//
|
||||
// ResultsCard.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Reusable results display card component for benchmark interface.
|
||||
* Shows comprehensive benchmark results with performance metrics and statistics.
|
||||
*/
|
||||
struct ResultsCard: View {
|
||||
let results: BenchmarkResults
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 20) {
|
||||
resultsHeader
|
||||
infoHeader
|
||||
performanceMetrics
|
||||
detailedStats
|
||||
}
|
||||
.padding(20)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(Color.benchmarkCardBg)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.stroke(Color.benchmarkSuccess.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Private Views
|
||||
|
||||
private var infoHeader: some View {
|
||||
|
||||
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
|
||||
|
||||
return VStack(alignment: .leading, spacing: 8) {
|
||||
Text(results.modelDisplayName)
|
||||
.font(.headline)
|
||||
Text(BenchmarkResultsHelper.shared.getDeviceInfo())
|
||||
.font(.subheadline)
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
Text("Benchmark Config")
|
||||
.font(.headline)
|
||||
Text(statistics.configText)
|
||||
.font(.subheadline)
|
||||
.lineLimit(nil)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
|
||||
private var resultsHeader: some View {
|
||||
HStack {
|
||||
HStack(spacing: 12) {
|
||||
ZStack {
|
||||
Circle()
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkSuccess.opacity(0.2), Color.benchmarkSuccess.opacity(0.1)],
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
Image(systemName: "chart.bar.fill")
|
||||
.font(.system(size: 18, weight: .semibold))
|
||||
.foregroundColor(.benchmarkSuccess)
|
||||
}
|
||||
|
||||
VStack(alignment: .leading, spacing: 2) {
|
||||
Text("Benchmark Results")
|
||||
.font(.title3)
|
||||
.fontWeight(.semibold)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text("Performance analysis complete")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Button(action: {
|
||||
shareResults()
|
||||
}) {
|
||||
VStack(alignment: .center, spacing: 2) {
|
||||
Image(systemName: "square.and.arrow.up")
|
||||
.font(.title2)
|
||||
.foregroundColor(.benchmarkSuccess)
|
||||
|
||||
Text("Share")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
.buttonStyle(PlainButtonStyle())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private var performanceMetrics: some View {
|
||||
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
|
||||
|
||||
return VStack(spacing: 16) {
|
||||
HStack(spacing: 12) {
|
||||
if let prefillStats = statistics.prefillStats {
|
||||
PerformanceMetricView(
|
||||
icon: "speedometer",
|
||||
title: "Prefill Speed",
|
||||
value: BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(prefillStats),
|
||||
subtitle: "Tokens per second",
|
||||
color: .benchmarkGradientStart
|
||||
)
|
||||
} else {
|
||||
PerformanceMetricView(
|
||||
icon: "speedometer",
|
||||
title: "Prefill Speed",
|
||||
value: "N/A",
|
||||
subtitle: "Tokens per second",
|
||||
color: .benchmarkGradientStart
|
||||
)
|
||||
}
|
||||
|
||||
if let decodeStats = statistics.decodeStats {
|
||||
PerformanceMetricView(
|
||||
icon: "gauge",
|
||||
title: "Decode Speed",
|
||||
value: BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(decodeStats),
|
||||
subtitle: "Generation rate",
|
||||
color: .benchmarkGradientEnd
|
||||
)
|
||||
} else {
|
||||
PerformanceMetricView(
|
||||
icon: "gauge",
|
||||
title: "Decode Speed",
|
||||
value: "N/A",
|
||||
subtitle: "Generation rate",
|
||||
color: .benchmarkGradientEnd
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
HStack(spacing: 12) {
|
||||
let totalMemoryKb = BenchmarkResultsHelper.shared.getTotalSystemMemoryKb()
|
||||
let memoryInfo = BenchmarkResultsHelper.shared.formatMemoryUsage(
|
||||
maxMemoryKb: results.maxMemoryKb,
|
||||
totalKb: totalMemoryKb
|
||||
)
|
||||
|
||||
PerformanceMetricView(
|
||||
icon: "memorychip",
|
||||
title: "Memory Usage",
|
||||
value: memoryInfo.valueText,
|
||||
subtitle: "Peak memory",
|
||||
color: .benchmarkWarning
|
||||
)
|
||||
|
||||
PerformanceMetricView(
|
||||
icon: "clock",
|
||||
title: "Total Tokens",
|
||||
value: "\(statistics.totalTokensProcessed)",
|
||||
subtitle: "Complete duration",
|
||||
color: .benchmarkSuccess
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var detailedStats: some View {
|
||||
return VStack(alignment: .leading, spacing: 12) {
|
||||
VStack(spacing: 8) {
|
||||
HStack {
|
||||
Text("Completed")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
Spacer()
|
||||
Text(results.timestamp)
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
|
||||
HStack {
|
||||
Text("Powered By MNN")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
Spacer()
|
||||
Text(verbatim: "https://github.com/alibaba/MNN")
|
||||
.font(.caption)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helper Functions
|
||||
|
||||
/// Formats byte count into human-readable string
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useKB, .useMB, .useGB]
|
||||
formatter.countStyle = .file
|
||||
return formatter.string(fromByteCount: bytes)
|
||||
}
|
||||
|
||||
/// Initiates sharing of benchmark results through system share sheet
|
||||
private func shareResults() {
|
||||
let viewToRender = self.body.frame(width: 390) // Adjust width as needed
|
||||
if let image = viewToRender.snapshot() {
|
||||
presentShareSheet(activityItems: [image, formatResultsForSharing()])
|
||||
} else {
|
||||
presentShareSheet(activityItems: [formatResultsForSharing()])
|
||||
}
|
||||
}
|
||||
|
||||
private func presentShareSheet(activityItems: [Any]) {
|
||||
let activityViewController = UIActivityViewController(activityItems: activityItems, applicationActivities: nil)
|
||||
|
||||
if let windowScene = UIApplication.shared.connectedScenes.first as? UIWindowScene,
|
||||
let window = windowScene.windows.first,
|
||||
let rootViewController = window.rootViewController {
|
||||
|
||||
if let popover = activityViewController.popoverPresentationController {
|
||||
popover.sourceView = window
|
||||
popover.sourceRect = CGRect(x: window.bounds.midX, y: window.bounds.midY, width: 0, height: 0)
|
||||
popover.permittedArrowDirections = []
|
||||
}
|
||||
|
||||
rootViewController.present(activityViewController, animated: true)
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats benchmark results into shareable text format with performance metrics and hashtags
|
||||
private func formatResultsForSharing() -> String {
|
||||
let statistics = BenchmarkResultsHelper.shared.processTestResults(results.testResults)
|
||||
let deviceInfo = BenchmarkResultsHelper.shared.getDeviceInfo()
|
||||
|
||||
var shareText = """
|
||||
📱 MNN LLM Benchmark Results
|
||||
|
||||
🤖 Model: \(results.modelDisplayName)
|
||||
📱 \(deviceInfo)
|
||||
📅 Completed: \(results.timestamp)
|
||||
|
||||
📊 Configuration:
|
||||
\(statistics.configText)
|
||||
|
||||
⚡️ Performance Results:
|
||||
"""
|
||||
|
||||
if let prefillStats = statistics.prefillStats {
|
||||
shareText += "\n🔄 Prompt Processing: \(BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(prefillStats))"
|
||||
}
|
||||
|
||||
if let decodeStats = statistics.decodeStats {
|
||||
shareText += "\n⚡️ Token Generation: \(BenchmarkResultsHelper.shared.formatSpeedStatisticsLine(decodeStats))"
|
||||
}
|
||||
|
||||
let totalMemoryKb = BenchmarkResultsHelper.shared.getTotalSystemMemoryKb()
|
||||
let memoryInfo = BenchmarkResultsHelper.shared.formatMemoryUsage(
|
||||
maxMemoryKb: results.maxMemoryKb,
|
||||
totalKb: totalMemoryKb
|
||||
)
|
||||
shareText += "\n💾 Peak Memory: \(memoryInfo.valueText) (\(memoryInfo.labelText))"
|
||||
|
||||
shareText += "\n\n📈 Summary:"
|
||||
shareText += "\n• Total Tokens Processed: \(statistics.totalTokensProcessed)"
|
||||
shareText += "\n• Number of Tests: \(statistics.totalTests)"
|
||||
|
||||
shareText += "\n\n#MNNLLMBenchmark #AIPerformance #MobileAI"
|
||||
|
||||
return shareText
|
||||
}
|
||||
}
|
||||
|
||||
extension View {
|
||||
func snapshot() -> UIImage? {
|
||||
let controller = UIHostingController(rootView: self)
|
||||
let view = controller.view
|
||||
|
||||
let targetSize = controller.view.intrinsicContentSize
|
||||
view?.bounds = CGRect(origin: .zero, size: targetSize)
|
||||
view?.backgroundColor = .clear
|
||||
|
||||
let renderer = UIGraphicsImageRenderer(size: targetSize)
|
||||
|
||||
return renderer.image { _ in
|
||||
view?.drawHierarchy(in: controller.view.bounds, afterScreenUpdates: true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
ResultsCard(
|
||||
results: BenchmarkResults(
|
||||
modelDisplayName: "Qwen2.5-1.5B-Instruct",
|
||||
maxMemoryKb: 1200000, // 1.2 GB in KB
|
||||
testResults: [],
|
||||
timestamp: "2025-01-21 14:30:25"
|
||||
)
|
||||
)
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
//
|
||||
// StatusCard.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Reusable status display card component for benchmark interface.
|
||||
* Shows status messages and updates to provide user feedback.
|
||||
*/
|
||||
struct StatusCard: View {
|
||||
let statusMessage: String
|
||||
|
||||
var body: some View {
|
||||
HStack(spacing: 16) {
|
||||
ZStack {
|
||||
Circle()
|
||||
.fill(
|
||||
LinearGradient(
|
||||
colors: [Color.benchmarkWarning.opacity(0.2), Color.benchmarkWarning.opacity(0.1)],
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
Image(systemName: "info.circle")
|
||||
.font(.system(size: 18, weight: .semibold))
|
||||
.foregroundColor(.benchmarkWarning)
|
||||
}
|
||||
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Status Update")
|
||||
.font(.subheadline)
|
||||
.fontWeight(.semibold)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text(statusMessage)
|
||||
.font(.subheadline)
|
||||
.foregroundColor(.benchmarkSecondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
}
|
||||
.padding(20)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(Color.benchmarkCardBg)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.stroke(Color.benchmarkWarning.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
StatusCard(statusMessage: "Initializing benchmark test environment...")
|
||||
.padding()
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
//
|
||||
// BenchmarkView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/21.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/**
|
||||
* Main benchmark view that provides interface for running performance tests on ML models.
|
||||
* Features include model selection, progress tracking, and results visualization.
|
||||
*/
|
||||
struct BenchmarkView: View {
|
||||
@StateObject private var viewModel = BenchmarkViewModel()
|
||||
@State private var showStopConfirmation = false
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
ScrollView {
|
||||
VStack(spacing: 24) {
|
||||
// Model Selection Section
|
||||
ModelSelectionCard(
|
||||
viewModel: viewModel,
|
||||
showStopConfirmation: $showStopConfirmation
|
||||
)
|
||||
|
||||
// Progress Section
|
||||
if viewModel.showProgressBar {
|
||||
ProgressCard(progress: viewModel.currentProgress)
|
||||
.transition(.asymmetric(
|
||||
insertion: .scale.combined(with: .opacity),
|
||||
removal: .opacity
|
||||
))
|
||||
}
|
||||
|
||||
// Status Section
|
||||
if !viewModel.statusMessage.isEmpty {
|
||||
StatusCard(statusMessage: viewModel.statusMessage)
|
||||
.transition(.slide)
|
||||
}
|
||||
|
||||
// Results Section
|
||||
if viewModel.showResults, let results = viewModel.benchmarkResults {
|
||||
ResultsCard(results: results)
|
||||
.transition(.asymmetric(
|
||||
insertion: .move(edge: .bottom).combined(with: .opacity),
|
||||
removal: .opacity
|
||||
))
|
||||
}
|
||||
|
||||
Spacer(minLength: 20)
|
||||
}
|
||||
.padding(.horizontal, 20)
|
||||
.padding(.vertical, 16)
|
||||
}
|
||||
}
|
||||
.alert("Stop Benchmark", isPresented: $showStopConfirmation) {
|
||||
Button("Yes", role: .destructive) {
|
||||
viewModel.onStopBenchmarkTapped()
|
||||
}
|
||||
Button("No", role: .cancel) { }
|
||||
} message: {
|
||||
Text("Are you sure you want to stop the benchmark test?")
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK") { }
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
.onReceive(viewModel.$isRunning) { isRunning in
|
||||
if isRunning && viewModel.startButtonText.contains("Stop") {
|
||||
showStopConfirmation = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
//
|
||||
// CommonToolbarView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/07/18.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct CommonToolbarView: ToolbarContent {
|
||||
@Binding var showHistory: Bool
|
||||
@Binding var showHistoryButton: Bool
|
||||
|
||||
var body: some ToolbarContent {
|
||||
ToolbarItem(placement: .navigationBarLeading) {
|
||||
if showHistoryButton {
|
||||
Button(action: {
|
||||
showHistory = true
|
||||
showHistoryButton = false
|
||||
}) {
|
||||
Image(systemName: "sidebar.left")
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fit)
|
||||
.frame(width: 20, height: 20)
|
||||
.foregroundColor(.black)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ToolbarItem(placement: .navigationBarTrailing) {
|
||||
Button(action: {
|
||||
if let url = URL(string: "https://github.com/alibaba/MNN") {
|
||||
UIApplication.shared.open(url)
|
||||
}
|
||||
}) {
|
||||
Image(systemName: "star")
|
||||
.resizable()
|
||||
.aspectRatio(contentMode: .fit)
|
||||
.frame(width: 20, height: 20)
|
||||
.foregroundColor(.black)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
//
|
||||
// MNNLLMiOSApp.swift
|
||||
// LocalModelListView
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/06/20.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct LocalModelListView: View {
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
|
||||
var body: some View {
|
||||
List {
|
||||
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
|
||||
Button(action: {
|
||||
viewModel.selectModel(model)
|
||||
}) {
|
||||
LocalModelRowView(model: model)
|
||||
}
|
||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
|
||||
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
||||
SwipeActionsView(model: model, viewModel: viewModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
.listStyle(.plain)
|
||||
.refreshable {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK", role: .cancel) {}
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
//
|
||||
// LocalModelRowView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/6/26.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct LocalModelRowView: View {
|
||||
|
||||
let model: ModelInfo
|
||||
|
||||
private var localizedTags: [String] {
|
||||
model.localizedTags
|
||||
}
|
||||
|
||||
private var formattedSize: String {
|
||||
model.formattedSize
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .center) {
|
||||
|
||||
ModelIconView(modelId: model.id)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
Text(model.modelName)
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
.lineLimit(1)
|
||||
|
||||
if !localizedTags.isEmpty {
|
||||
TagsView(tags: localizedTags)
|
||||
}
|
||||
|
||||
HStack {
|
||||
HStack(alignment: .center, spacing: 2) {
|
||||
Image(systemName: "folder")
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
.frame(width: 20, height: 20)
|
||||
|
||||
Text(formattedSize)
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
if let lastUsedAt = model.lastUsedAt {
|
||||
Text("\(lastUsedAt.formatAgo())")
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
//
|
||||
// MNNLLMiOSApp.swift
|
||||
// MainTabView
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/06/20.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MainTabView is the primary view of the app, containing the tab bar and navigation for main sections.
|
||||
struct MainTabView: View {
|
||||
// MARK: - State Properties
|
||||
|
||||
@State private var showHistory = false
|
||||
@State private var selectedHistory: ChatHistory? = nil
|
||||
@State private var histories: [ChatHistory] = ChatHistoryManager.shared.getAllHistory()
|
||||
@State private var showHistoryButton = true
|
||||
@State private var showSettings = false
|
||||
@State private var showWebView = false
|
||||
@State private var webViewURL: URL?
|
||||
@State private var navigateToSettings = false
|
||||
@StateObject private var modelListViewModel = ModelListViewModel()
|
||||
@State private var selectedTab: Int = 0
|
||||
|
||||
private var titles: [String] {
|
||||
[
|
||||
NSLocalizedString("Local Model", comment: "本地模型标签"),
|
||||
NSLocalizedString("Model Market", comment: "模型市场标签"),
|
||||
NSLocalizedString("Benchmark", comment: "基准测试标签")
|
||||
]
|
||||
}
|
||||
|
||||
// MARK: - Body
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
// Main TabView for navigation between Local Model, Model Market, and Benchmark
|
||||
TabView(selection: $selectedTab) {
|
||||
NavigationView {
|
||||
LocalModelListView(viewModel: modelListViewModel)
|
||||
.navigationTitle(titles[0])
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.navigationBarHidden(false)
|
||||
.onAppear {
|
||||
setupNavigationBarAppearance()
|
||||
}
|
||||
.toolbar {
|
||||
CommonToolbarView(
|
||||
showHistory: $showHistory,
|
||||
showHistoryButton: $showHistoryButton,
|
||||
)
|
||||
}
|
||||
.background(
|
||||
ZStack {
|
||||
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
|
||||
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
|
||||
}
|
||||
)
|
||||
// Hide TabBar when entering chat or settings view
|
||||
.toolbar((chatIsActiveBinding.wrappedValue || navigateToSettings) ? .hidden : .visible, for: .tabBar)
|
||||
}
|
||||
.tabItem {
|
||||
Image(systemName: "house.fill")
|
||||
Text(titles[0])
|
||||
}
|
||||
.tag(0)
|
||||
|
||||
NavigationView {
|
||||
ModelListView(viewModel: modelListViewModel)
|
||||
.navigationTitle(titles[1])
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.navigationBarHidden(false)
|
||||
.onAppear {
|
||||
setupNavigationBarAppearance()
|
||||
}
|
||||
.toolbar {
|
||||
CommonToolbarView(
|
||||
showHistory: $showHistory,
|
||||
showHistoryButton: $showHistoryButton,
|
||||
)
|
||||
}
|
||||
.background(
|
||||
ZStack {
|
||||
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
|
||||
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
|
||||
}
|
||||
)
|
||||
}
|
||||
.tabItem {
|
||||
Image(systemName: "doc.text.fill")
|
||||
Text(titles[1])
|
||||
}
|
||||
.tag(1)
|
||||
|
||||
NavigationView {
|
||||
BenchmarkView()
|
||||
.navigationTitle(titles[2])
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.navigationBarHidden(false)
|
||||
.onAppear {
|
||||
setupNavigationBarAppearance()
|
||||
}
|
||||
.toolbar {
|
||||
CommonToolbarView(
|
||||
showHistory: $showHistory,
|
||||
showHistoryButton: $showHistoryButton,
|
||||
)
|
||||
}
|
||||
.background(
|
||||
ZStack {
|
||||
NavigationLink(destination: chatDestination, isActive: chatIsActiveBinding) { EmptyView() }
|
||||
NavigationLink(destination: SettingsView(), isActive: $navigateToSettings) { EmptyView() }
|
||||
}
|
||||
)
|
||||
}
|
||||
.tabItem {
|
||||
Image(systemName: "clock.fill")
|
||||
Text(titles[2])
|
||||
}
|
||||
.tag(2)
|
||||
}
|
||||
.onAppear {
|
||||
setupTabBarAppearance()
|
||||
}
|
||||
.tint(.black)
|
||||
|
||||
// Overlay for dimming the background when history is shown
|
||||
if showHistory {
|
||||
Color.black.opacity(0.5)
|
||||
.edgesIgnoringSafeArea(.all)
|
||||
.onTapGesture {
|
||||
withAnimation(.easeInOut(duration: 0.2)) {
|
||||
showHistory = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Side menu for displaying chat history
|
||||
SideMenuView(isOpen: $showHistory,
|
||||
selectedHistory: $selectedHistory,
|
||||
histories: $histories,
|
||||
navigateToMainSettings: $navigateToSettings)
|
||||
.edgesIgnoringSafeArea(.all)
|
||||
}
|
||||
.onChange(of: showHistory) { oldValue, newValue in
|
||||
if !newValue {
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
|
||||
withAnimation {
|
||||
showHistoryButton = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.sheet(isPresented: $showWebView) {
|
||||
if let url = webViewURL {
|
||||
WebView(url: url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - View Builders
|
||||
|
||||
/// Destination view for chat, either from a new model or a history item.
|
||||
@ViewBuilder
|
||||
private var chatDestination: some View {
|
||||
if let model = modelListViewModel.selectedModel {
|
||||
LLMChatView(modelInfo: model)
|
||||
.navigationBarHidden(false)
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
} else if let history = selectedHistory {
|
||||
let modelInfo = ModelInfo(modelId: history.modelId, isDownloaded: true)
|
||||
LLMChatView(modelInfo: modelInfo, history: history)
|
||||
.navigationBarHidden(false)
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
} else {
|
||||
EmptyView()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Bindings
|
||||
|
||||
/// Binding to control the activation of the chat view.
|
||||
private var chatIsActiveBinding: Binding<Bool> {
|
||||
Binding<Bool>(
|
||||
get: {
|
||||
return modelListViewModel.selectedModel != nil || selectedHistory != nil
|
||||
},
|
||||
set: { isActive in
|
||||
if !isActive {
|
||||
// Record usage when returning from chat
|
||||
if let model = modelListViewModel.selectedModel {
|
||||
modelListViewModel.recordModelUsage(modelName: model.modelName)
|
||||
}
|
||||
|
||||
// Clear selections
|
||||
modelListViewModel.selectedModel = nil
|
||||
selectedHistory = nil
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Private Methods
|
||||
|
||||
/// Configures the appearance of the navigation bar.
|
||||
private func setupNavigationBarAppearance() {
|
||||
let appearance = UINavigationBarAppearance()
|
||||
appearance.configureWithOpaqueBackground()
|
||||
appearance.backgroundColor = .white
|
||||
appearance.shadowColor = .clear
|
||||
|
||||
UINavigationBar.appearance().standardAppearance = appearance
|
||||
UINavigationBar.appearance().compactAppearance = appearance
|
||||
UINavigationBar.appearance().scrollEdgeAppearance = appearance
|
||||
}
|
||||
|
||||
/// Configures the appearance of the tab bar.
|
||||
private func setupTabBarAppearance() {
|
||||
let appearance = UITabBarAppearance()
|
||||
appearance.configureWithOpaqueBackground()
|
||||
|
||||
let selectedColor = UIColor(Color.primaryPurple)
|
||||
|
||||
appearance.stackedLayoutAppearance.selected.iconColor = selectedColor
|
||||
appearance.stackedLayoutAppearance.selected.titleTextAttributes = [.foregroundColor: selectedColor]
|
||||
|
||||
UITabBar.appearance().standardAppearance = appearance
|
||||
UITabBar.appearance().scrollEdgeAppearance = appearance
|
||||
}
|
||||
}
|
|
@ -0,0 +1,203 @@
|
|||
//
|
||||
// TBModelInfo.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
struct ModelInfo: Codable {
|
||||
// MARK: - Properties
|
||||
let modelName: String
|
||||
let tags: [String]
|
||||
let categories: [String]?
|
||||
let size_gb: Double?
|
||||
let vendor: String?
|
||||
let sources: [String: String]?
|
||||
let tagTranslations: [String: [String]]?
|
||||
|
||||
// Runtime properties
|
||||
var isDownloaded: Bool = false
|
||||
var lastUsedAt: Date?
|
||||
var cachedSize: Int64? = nil
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
init(modelName: String = "",
|
||||
tags: [String] = [],
|
||||
categories: [String]? = nil,
|
||||
size_gb: Double? = nil,
|
||||
vendor: String? = nil,
|
||||
sources: [String: String]? = nil,
|
||||
tagTranslations: [String: [String]]? = nil,
|
||||
isDownloaded: Bool = false,
|
||||
lastUsedAt: Date? = nil,
|
||||
cachedSize: Int64? = nil) {
|
||||
|
||||
self.modelName = modelName
|
||||
self.tags = tags
|
||||
self.categories = categories
|
||||
self.size_gb = size_gb
|
||||
self.vendor = vendor
|
||||
self.sources = sources
|
||||
self.tagTranslations = tagTranslations
|
||||
self.isDownloaded = isDownloaded
|
||||
self.lastUsedAt = lastUsedAt
|
||||
self.cachedSize = cachedSize
|
||||
}
|
||||
|
||||
init(modelId: String, isDownloaded: Bool = true) {
|
||||
let modelName = modelId.components(separatedBy: "/").last ?? modelId
|
||||
|
||||
self.init(
|
||||
modelName: modelName,
|
||||
tags: [],
|
||||
sources: ["huggingface": modelId],
|
||||
isDownloaded: isDownloaded
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Model Identity & Localization
|
||||
|
||||
var id: String {
|
||||
guard let sources = sources else {
|
||||
return "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
|
||||
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
var localizedTags: [String] {
|
||||
let currentLanguage = LanguageManager.shared.currentLanguage
|
||||
let isChineseLanguage = currentLanguage == "简体中文"
|
||||
|
||||
if isChineseLanguage, let translations = tagTranslations {
|
||||
let languageCode = "zh-Hans"
|
||||
return translations[languageCode] ?? tags
|
||||
} else {
|
||||
return tags
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - File System & Path Management
|
||||
|
||||
var localPath: String {
|
||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
|
||||
}
|
||||
|
||||
// MARK: - Size Calculation & Formatting
|
||||
|
||||
var formattedSize: String {
|
||||
if let cached = cachedSize {
|
||||
return FileOperationManager.shared.formatBytes(cached)
|
||||
} else if isDownloaded {
|
||||
return FileOperationManager.shared.formatLocalDirectorySize(at: localPath)
|
||||
} else if let sizeGb = size_gb {
|
||||
return String(format: "%.1f GB", sizeGb)
|
||||
} else {
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates and caches the local directory size
|
||||
/// - Returns: The formatted size string and updates cachedSize property
|
||||
mutating func calculateAndCacheSize() -> String {
|
||||
if let cached = cachedSize {
|
||||
return FileOperationManager.shared.formatBytes(cached)
|
||||
}
|
||||
|
||||
if isDownloaded {
|
||||
do {
|
||||
let sizeInBytes = try FileOperationManager.shared.calculateDirectorySize(at: localPath)
|
||||
self.cachedSize = sizeInBytes
|
||||
return FileOperationManager.shared.formatBytes(sizeInBytes)
|
||||
} catch {
|
||||
print("Error calculating directory size: \(error)")
|
||||
return "Unknown"
|
||||
}
|
||||
} else if let sizeGb = size_gb {
|
||||
return String(format: "%.1f GB", sizeGb)
|
||||
} else {
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Remote Size Calculation
|
||||
|
||||
func fetchRemoteSize() async -> Int64? {
|
||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||
|
||||
do {
|
||||
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
||||
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
||||
return totalSize
|
||||
} catch {
|
||||
print("Error fetching remote size for \(id): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
||||
let url = try buildURL(
|
||||
repoPath: repoPath,
|
||||
path: "/repo/files",
|
||||
queryItems: [
|
||||
URLQueryItem(name: "Root", value: root),
|
||||
URLQueryItem(name: "Revision", value: revision)
|
||||
]
|
||||
)
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(from: url)
|
||||
try validateResponse(response)
|
||||
|
||||
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
|
||||
return modelResponse.data.files
|
||||
}
|
||||
|
||||
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
for file in files {
|
||||
if file.type == "tree" {
|
||||
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
|
||||
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
|
||||
} else if file.type == "blob" {
|
||||
totalSize += Int64(file.size)
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
// MARK: - Network Utilities
|
||||
|
||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
||||
var components = URLComponents()
|
||||
components.scheme = "https"
|
||||
components.host = "modelscope.cn"
|
||||
components.path = "/api/v1/models/\(repoPath)\(path)"
|
||||
components.queryItems = queryItems
|
||||
|
||||
guard let url = components.url else {
|
||||
throw ModelScopeError.invalidURL
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
private func validateResponse(_ response: URLResponse) throws {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200...299).contains(httpResponse.statusCode) else {
|
||||
throw ModelScopeError.invalidResponse
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Codable
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case modelName, tags, categories, size_gb, vendor, sources, tagTranslations, cachedSize
|
||||
}
|
||||
}
|
|
@ -0,0 +1,361 @@
|
|||
//
|
||||
// ModelListViewModel.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
|
||||
class ModelListViewModel: ObservableObject {
|
||||
// MARK: - Published Properties
|
||||
@Published var models: [ModelInfo] = []
|
||||
@Published var searchText = ""
|
||||
@Published var quickFilterTags: [String] = []
|
||||
@Published var selectedModel: ModelInfo?
|
||||
@Published var showError = false
|
||||
@Published var errorMessage = ""
|
||||
|
||||
// Download state
|
||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||
@Published private(set) var currentlyDownloading: String?
|
||||
|
||||
// MARK: - Private Properties
|
||||
private let modelClient = ModelClient()
|
||||
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
||||
|
||||
// MARK: - Model Data Access
|
||||
|
||||
public var pinnedModelIds: [String] {
|
||||
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
||||
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
||||
}
|
||||
|
||||
var allTags: [String] {
|
||||
Array(Set(models.flatMap { $0.tags }))
|
||||
}
|
||||
|
||||
var allCategories: [String] {
|
||||
Array(Set(models.compactMap { $0.categories }.flatMap { $0 }))
|
||||
}
|
||||
|
||||
var allVendors: [String] {
|
||||
Array(Set(models.compactMap { $0.vendor }))
|
||||
}
|
||||
|
||||
var filteredModels: [ModelInfo] {
|
||||
let filtered = searchText.isEmpty ? models : models.filter { model in
|
||||
model.id.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
||||
}
|
||||
|
||||
let downloaded = filtered.filter { $0.isDownloaded }
|
||||
let notDownloaded = filtered.filter { !$0.isDownloaded }
|
||||
|
||||
return downloaded + notDownloaded
|
||||
}
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
init() {
|
||||
Task { @MainActor in
|
||||
await fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Data Management
|
||||
|
||||
@MainActor
|
||||
func fetchModels() async {
|
||||
do {
|
||||
let info = try await modelClient.getModelInfo()
|
||||
|
||||
self.quickFilterTags = info.quickFilterTags ?? []
|
||||
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
|
||||
|
||||
var fetchedModels = info.models
|
||||
|
||||
filterDiffusionModels(fetchedModels: &fetchedModels)
|
||||
loadCachedSizes(for: &fetchedModels)
|
||||
sortModels(fetchedModels: &fetchedModels)
|
||||
self.models = fetchedModels
|
||||
|
||||
// Asynchronously fetch size info for both downloaded and undownloaded models
|
||||
Task {
|
||||
await fetchModelSizes(for: fetchedModels)
|
||||
}
|
||||
|
||||
} catch {
|
||||
showError = true
|
||||
errorMessage = "Error: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
|
||||
private func loadCachedSizes(for models: inout [ModelInfo]) {
|
||||
for i in 0..<models.count {
|
||||
if let cachedSize = ModelStorageManager.shared.getCachedSize(for: models[i].modelName) {
|
||||
models[i].cachedSize = cachedSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchModelSizes(for models: [ModelInfo]) async {
|
||||
await withTaskGroup(of: Void.self) { group in
|
||||
for (_, model) in models.enumerated() {
|
||||
// Handle undownloaded models - fetch remote size
|
||||
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
|
||||
group.addTask {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[modelIndex].cachedSize = size
|
||||
ModelStorageManager.shared.setCachedSize(size, for: model.modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle downloaded models - calculate and cache local directory size
|
||||
if model.isDownloaded && model.cachedSize == nil {
|
||||
group.addTask {
|
||||
do {
|
||||
let localSize = try FileOperationManager.shared.calculateDirectorySize(at: model.localPath)
|
||||
await MainActor.run {
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[modelIndex].cachedSize = localSize
|
||||
ModelStorageManager.shared.setCachedSize(localSize, for: model.modelName)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("Error calculating local directory size for \(model.modelName): \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func filterDiffusionModels(fetchedModels: inout [ModelInfo]) {
|
||||
let hasDiffusionModels = fetchedModels.contains {
|
||||
$0.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
if hasDiffusionModels {
|
||||
fetchedModels = fetchedModels.filter { model in
|
||||
let name = model.modelName.lowercased()
|
||||
let tags = model.tags.map { $0.lowercased() }
|
||||
|
||||
// Only show GPU diffusion models
|
||||
if name.contains("diffusion") {
|
||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..<fetchedModels.count {
|
||||
let model = fetchedModels[i]
|
||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
|
||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
|
||||
}
|
||||
}
|
||||
|
||||
private func sortModels(fetchedModels: inout [ModelInfo]) {
|
||||
let pinned = pinnedModelIds
|
||||
|
||||
fetchedModels.sort { (model1, model2) -> Bool in
|
||||
let isPinned1 = pinned.contains(model1.id)
|
||||
let isPinned2 = pinned.contains(model2.id)
|
||||
let isDownloading1 = currentlyDownloading == model1.id
|
||||
let isDownloading2 = currentlyDownloading == model2.id
|
||||
|
||||
// 1. Currently downloading models have highest priority
|
||||
if isDownloading1 != isDownloading2 {
|
||||
return isDownloading1
|
||||
}
|
||||
|
||||
// 2. Pinned models have second priority
|
||||
if isPinned1 != isPinned2 {
|
||||
return isPinned1
|
||||
}
|
||||
|
||||
// 3. If both are pinned, sort by pin time
|
||||
if isPinned1 && isPinned2 {
|
||||
let index1 = pinned.firstIndex(of: model1.id)!
|
||||
let index2 = pinned.firstIndex(of: model2.id)!
|
||||
return index1 > index2 // Pinned later comes first
|
||||
}
|
||||
|
||||
// 4. Non-pinned models sorted by download status
|
||||
if model1.isDownloaded != model2.isDownloaded {
|
||||
return model1.isDownloaded
|
||||
}
|
||||
|
||||
// 5. If both downloaded, sort by last used time
|
||||
if model1.isDownloaded {
|
||||
let date1 = model1.lastUsedAt ?? .distantPast
|
||||
let date2 = model2.lastUsedAt ?? .distantPast
|
||||
return date1 > date2
|
||||
}
|
||||
|
||||
return false // Keep original order for not-downloaded
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Selection & Usage
|
||||
|
||||
@MainActor
|
||||
func selectModel(_ model: ModelInfo) {
|
||||
if model.isDownloaded {
|
||||
selectedModel = model
|
||||
} else {
|
||||
Task {
|
||||
await downloadModel(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func recordModelUsage(modelName: String) {
|
||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
||||
Task { @MainActor in
|
||||
if let index = self.models.firstIndex(where: { $0.modelName == modelName }) {
|
||||
self.models[index].lastUsedAt = Date()
|
||||
self.sortModels(fetchedModels: &self.models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Download Management
|
||||
|
||||
func downloadModel(_ model: ModelInfo) async {
|
||||
await MainActor.run {
|
||||
guard currentlyDownloading == nil else { return }
|
||||
currentlyDownloading = model.id
|
||||
downloadProgress[model.id] = 0
|
||||
}
|
||||
|
||||
do {
|
||||
try await modelClient.downloadModel(model: model) { progress in
|
||||
Task { @MainActor in
|
||||
self.downloadProgress[model.id] = progress
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[index].isDownloaded = true
|
||||
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate and cache size for newly downloaded model
|
||||
do {
|
||||
let localSize = try FileOperationManager.shared.calculateDirectorySize(at: model.localPath)
|
||||
await MainActor.run {
|
||||
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[index].cachedSize = localSize
|
||||
ModelStorageManager.shared.setCachedSize(localSize, for: model.modelName)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("Error calculating size for newly downloaded model \(model.modelName): \(error)")
|
||||
}
|
||||
|
||||
} catch {
|
||||
await MainActor.run {
|
||||
if case ModelScopeError.downloadCancelled = error {
|
||||
print("Download was cancelled")
|
||||
} else {
|
||||
self.showError = true
|
||||
self.errorMessage = "Failed to download model: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
self.currentlyDownloading = nil
|
||||
self.downloadProgress.removeValue(forKey: model.id)
|
||||
}
|
||||
}
|
||||
|
||||
func cancelDownload() async {
|
||||
let modelId = await MainActor.run { currentlyDownloading }
|
||||
|
||||
if let modelId = modelId {
|
||||
await modelClient.cancelDownload()
|
||||
|
||||
await MainActor.run {
|
||||
self.downloadProgress.removeValue(forKey: modelId)
|
||||
self.currentlyDownloading = nil
|
||||
}
|
||||
|
||||
print("Download cancelled for model: \(modelId)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Pin Management
|
||||
|
||||
@MainActor
|
||||
func pinModel(_ model: ModelInfo) {
|
||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||
let pinned = models.remove(at: index)
|
||||
models.insert(pinned, at: 0)
|
||||
|
||||
var pinnedIds = pinnedModelIds
|
||||
if let existingIndex = pinnedIds.firstIndex(of: model.id) {
|
||||
pinnedIds.remove(at: existingIndex)
|
||||
}
|
||||
pinnedIds.insert(model.id, at: 0)
|
||||
pinnedModelIds = pinnedIds
|
||||
}
|
||||
|
||||
@MainActor
|
||||
func unpinModel(_ model: ModelInfo) {
|
||||
var pinnedIds = pinnedModelIds
|
||||
if let index = pinnedIds.firstIndex(of: model.id) {
|
||||
pinnedIds.remove(at: index)
|
||||
pinnedModelIds = pinnedIds
|
||||
|
||||
// Re-sort models after unpinning
|
||||
sortModels(fetchedModels: &models)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Deletion
|
||||
|
||||
func deleteModel(_ model: ModelInfo) async {
|
||||
guard model.isDownloaded else { return }
|
||||
|
||||
do {
|
||||
// Delete local files
|
||||
let fileManager = FileManager.default
|
||||
let modelPath = model.localPath
|
||||
|
||||
if fileManager.fileExists(atPath: modelPath) {
|
||||
try fileManager.removeItem(atPath: modelPath)
|
||||
}
|
||||
|
||||
// Update model state
|
||||
await MainActor.run {
|
||||
if let index = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[index].isDownloaded = false
|
||||
self.models[index].cachedSize = nil
|
||||
ModelStorageManager.shared.markModelAsNotDownloaded(model.modelName)
|
||||
}
|
||||
|
||||
// Re-sort models after deletion
|
||||
self.sortModels(fetchedModels: &self.models)
|
||||
}
|
||||
|
||||
} catch {
|
||||
await MainActor.run {
|
||||
self.showError = true
|
||||
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
//
|
||||
// TBDataResponse.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/9.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
struct TBDataResponse: Codable {
|
||||
let tagTranslations: [String: String]
|
||||
let quickFilterTags: [String]?
|
||||
let models: [ModelInfo]
|
||||
let metadata: Metadata?
|
||||
|
||||
struct Metadata: Codable {
|
||||
let version: String
|
||||
let lastUpdated: String
|
||||
let schemaVersion: String
|
||||
let totalModels: Int
|
||||
let supportedPlatforms: [String]
|
||||
let minAppVersion: String
|
||||
}
|
||||
}
|
|
@ -0,0 +1,196 @@
|
|||
//
|
||||
// ModelClient.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
class ModelClient {
|
||||
private let maxRetries = 5
|
||||
|
||||
private let baseMirrorURL = "https://hf-mirror.com"
|
||||
private let baseURL = "https://huggingface.co"
|
||||
private let AliCDNURL = "https://meta.alicdn.com/data/mnn/apis/model_market.json"
|
||||
|
||||
// Debug flag to use local mock data instead of network API
|
||||
private let useLocalMockData = false
|
||||
|
||||
private var currentDownloadManager: ModelScopeDownloadManager?
|
||||
|
||||
private lazy var baseURLString: String = {
|
||||
switch ModelSourceManager.shared.selectedSource {
|
||||
case .huggingFace:
|
||||
return baseURL
|
||||
default:
|
||||
return baseMirrorURL
|
||||
}
|
||||
}()
|
||||
|
||||
init() {}
|
||||
|
||||
func getModelInfo() async throws -> TBDataResponse {
|
||||
if useLocalMockData {
|
||||
// Debug mode: use local mock data
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
return mockResponse
|
||||
} else {
|
||||
// Production mode: fetch from network API
|
||||
return try await fetchDataFromAliAPI()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches data from the network API with fallback to local mock data
|
||||
*
|
||||
* @throws NetworkError if both network request and local fallback fail
|
||||
*/
|
||||
private func fetchDataFromAliAPI() async throws -> TBDataResponse {
|
||||
do {
|
||||
guard let url = URL(string: AliCDNURL) else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(from: url)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
httpResponse.statusCode == 200 else {
|
||||
throw NetworkError.invalidResponse
|
||||
}
|
||||
|
||||
let apiResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
return apiResponse
|
||||
|
||||
} catch {
|
||||
print("Network request failed: \(error). Falling back to local mock data.")
|
||||
|
||||
// Fallback to local mock data if network request fails
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
return mockResponse
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a model from the selected source with progress tracking
|
||||
*
|
||||
* @param model The ModelInfo object containing model details
|
||||
* @param progress Progress callback that receives download progress (0.0 to 1.0)
|
||||
* @throws Various network or file system errors
|
||||
*/
|
||||
func downloadModel(model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
switch ModelSourceManager.shared.selectedSource {
|
||||
case .modelScope, .modeler:
|
||||
try await downloadFromModelScope(model, progress: progress)
|
||||
case .huggingFace:
|
||||
try await downloadFromHuggingFace(model, progress: progress)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancels the current download operation
|
||||
*/
|
||||
func cancelDownload() async {
|
||||
if let manager = currentDownloadManager {
|
||||
await manager.cancelDownload()
|
||||
currentDownloadManager = nil
|
||||
print("Download cancelled")
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Downloads model from ModelScope platform
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromModelScope(_ model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let ModelScopeId = model.id
|
||||
let config = URLSessionConfiguration.default
|
||||
config.timeoutIntervalForRequest = 30
|
||||
config.timeoutIntervalForResource = 300
|
||||
|
||||
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
||||
currentDownloadManager = manager
|
||||
|
||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
|
||||
Task { @MainActor in
|
||||
progress(fileProgress)
|
||||
}
|
||||
}
|
||||
|
||||
currentDownloadManager = nil
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads model from HuggingFace platform with optimized progress updates
|
||||
*
|
||||
* This method implements throttling to prevent UI stuttering by limiting
|
||||
* progress update frequency and filtering out minor progress changes.
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromHuggingFace(_ model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let repo = Hub.Repo(id: model.id)
|
||||
let modelFiles = ["*.*"]
|
||||
let mirrorHubApi = HubApi(endpoint: baseURL)
|
||||
|
||||
// Progress throttling mechanism to prevent UI stuttering
|
||||
var lastUpdateTime = Date()
|
||||
var lastProgress: Double = 0.0
|
||||
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
|
||||
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
|
||||
|
||||
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
||||
let currentProgress = fileProgress.fractionCompleted
|
||||
let currentTime = Date()
|
||||
|
||||
// Check if progress should be updated
|
||||
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
||||
let progressDiff = abs(currentProgress - lastProgress)
|
||||
|
||||
// Update progress if any of these conditions are met:
|
||||
// 1. Time interval exceeds threshold
|
||||
// 2. Progress change exceeds threshold
|
||||
// 3. Progress reaches 100% (download complete)
|
||||
// 4. Progress is 0% (download start)
|
||||
if timeDiff >= progressUpdateInterval ||
|
||||
progressDiff >= progressThreshold ||
|
||||
currentProgress >= 1.0 ||
|
||||
currentProgress == 0.0 {
|
||||
|
||||
lastUpdateTime = currentTime
|
||||
lastProgress = currentProgress
|
||||
|
||||
// Ensure progress updates are executed on the main thread
|
||||
Task { @MainActor in
|
||||
progress(currentProgress)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
enum NetworkError: Error {
|
||||
case invalidResponse
|
||||
case invalidData
|
||||
case downloadFailed
|
||||
case unknown
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
//
|
||||
// ModelClient.swift
|
||||
// ModelScopeDownloadManager.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/2/20.
|
||||
|
@ -31,6 +31,11 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
private var downloadedSize: Int64 = 0
|
||||
private var lastUpdatedBytes: Int64 = 0
|
||||
|
||||
// Download cancellation related properties
|
||||
private var isCancelled: Bool = false
|
||||
private var currentDownloadTask: Task<Void, Error>?
|
||||
private var currentFileHandle: FileHandle?
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a new ModelScope download manager
|
||||
|
@ -83,6 +88,9 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
modelName: String,
|
||||
progress: ((Double) -> Void)? = nil
|
||||
) async throws {
|
||||
|
||||
isCancelled = false
|
||||
|
||||
ModelScopeLogger.info("Starting download for modelId: \(modelId)")
|
||||
|
||||
let destination = try resolveDestinationPath(base: destinationFolder, modelId: modelName)
|
||||
|
@ -100,7 +108,28 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
)
|
||||
}
|
||||
|
||||
// MARK: - Private Methods
|
||||
/// Cancel download
|
||||
/// Preserve downloaded temporary files to support resume functionality
|
||||
public func cancelDownload() async {
|
||||
isCancelled = true
|
||||
|
||||
currentDownloadTask?.cancel()
|
||||
currentDownloadTask = nil
|
||||
|
||||
await closeFileHandle()
|
||||
|
||||
session.invalidateAndCancel()
|
||||
|
||||
ModelScopeLogger.info("Download cancelled, temporary files preserved for resume")
|
||||
}
|
||||
|
||||
// MARK: - Private Methods - Progress Management
|
||||
|
||||
private func updateProgress(_ progress: Double, callback: @escaping (Double) -> Void) {
|
||||
Task { @MainActor in
|
||||
callback(progress)
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchFileList(
|
||||
root: String,
|
||||
|
@ -131,6 +160,8 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
var lastError: Error?
|
||||
|
||||
for attempt in 1...maxRetries {
|
||||
if isCancelled { break }
|
||||
|
||||
do {
|
||||
print("Attempt \(attempt) of \(maxRetries) for file: \(file.name)")
|
||||
try await downloadFileWithRetry(
|
||||
|
@ -162,6 +193,11 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
destinationPath: String,
|
||||
onProgress: @escaping (Int64) -> Void
|
||||
) async throws {
|
||||
|
||||
if isCancelled {
|
||||
throw ModelScopeError.downloadCancelled
|
||||
}
|
||||
|
||||
let session = self.session
|
||||
|
||||
ModelScopeLogger.info("Starting download for file: \(file.name)")
|
||||
|
@ -209,28 +245,45 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
ModelScopeLogger.debug("Requesting URL: \(url)")
|
||||
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
Task {
|
||||
currentDownloadTask = Task {
|
||||
do {
|
||||
let (asyncBytes, response) = try await session.bytes(for: request)
|
||||
ModelScopeLogger.debug("Response status code: \((response as? HTTPURLResponse)?.statusCode ?? -1)")
|
||||
try validateResponse(response)
|
||||
|
||||
let fileHandle = try FileHandle(forWritingTo: tempURL)
|
||||
self.currentFileHandle = fileHandle
|
||||
|
||||
if resumeOffset > 0 {
|
||||
try fileHandle.seek(toOffset: UInt64(resumeOffset))
|
||||
}
|
||||
|
||||
var downloadedBytes: Int64 = resumeOffset
|
||||
var bytesCount = 0
|
||||
|
||||
for try await byte in asyncBytes {
|
||||
// Frequently check cancellation status
|
||||
if isCancelled {
|
||||
try fileHandle.close()
|
||||
self.currentFileHandle = nil
|
||||
// Don't delete temp files when cancelled, preserve resume functionality
|
||||
continuation.resume(throwing: ModelScopeError.downloadCancelled)
|
||||
return
|
||||
}
|
||||
|
||||
try fileHandle.write(contentsOf: [byte])
|
||||
downloadedBytes += 1
|
||||
if downloadedBytes % 1024 == 0 {
|
||||
bytesCount += 1
|
||||
|
||||
// 减少进度回调频率:每 64KB * 5 更新一次而不是每1KB
|
||||
if bytesCount >= 64 * 1024 * 5 {
|
||||
onProgress(downloadedBytes)
|
||||
bytesCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
try fileHandle.close()
|
||||
self.currentFileHandle = nil
|
||||
|
||||
let finalSize = try FileManager.default.attributesOfItem(atPath: tempURL.path)[.size] as? Int64 ?? 0
|
||||
guard finalSize == file.size else {
|
||||
|
@ -250,8 +303,16 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
onProgress(downloadedBytes)
|
||||
continuation.resume()
|
||||
} catch {
|
||||
ModelScopeLogger.error("Download failed: \(error.localizedDescription)")
|
||||
storage.clearFileStatus(at: destination.path)
|
||||
// Clean up file handle when handling errors
|
||||
if let handle = self.currentFileHandle {
|
||||
try? handle.close()
|
||||
self.setCurrentFileHandle(nil)
|
||||
}
|
||||
|
||||
if !isCancelled {
|
||||
ModelScopeLogger.error("Download failed: \(error.localizedDescription)")
|
||||
storage.clearFileStatus(at: destination.path)
|
||||
}
|
||||
continuation.resume(throwing: error)
|
||||
}
|
||||
}
|
||||
|
@ -266,6 +327,10 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
) async throws {
|
||||
ModelScopeLogger.info("Starting download with \(files.count) files")
|
||||
|
||||
if isCancelled {
|
||||
throw ModelScopeError.downloadCancelled
|
||||
}
|
||||
|
||||
func calculateTotalSize(files: [ModelFile]) async throws -> Int64 {
|
||||
var size: Int64 = 0
|
||||
for file in files {
|
||||
|
@ -288,6 +353,11 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
}
|
||||
|
||||
for file in files {
|
||||
|
||||
if Task.isCancelled || isCancelled {
|
||||
throw ModelScopeError.downloadCancelled
|
||||
}
|
||||
|
||||
ModelScopeLogger.debug("Processing: \(file.name), type: \(file.type)")
|
||||
|
||||
if file.type == "tree" {
|
||||
|
@ -317,15 +387,7 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
destinationPath: destinationPath,
|
||||
onProgress: { downloadedBytes in
|
||||
let currentProgress = Double(self.downloadedSize + downloadedBytes) / Double(self.totalSize)
|
||||
progress(currentProgress)
|
||||
// 1MB = 1,024 * 1,024
|
||||
let bytesDelta = self.downloadedSize - self.lastUpdatedBytes
|
||||
if bytesDelta >= 1_024 * 1_024 {
|
||||
self.lastUpdatedBytes = self.downloadedSize
|
||||
DispatchQueue.main.async {
|
||||
progress(currentProgress)
|
||||
}
|
||||
}
|
||||
self.updateProgress(currentProgress, callback: progress)
|
||||
},
|
||||
maxRetries: 500,
|
||||
retryDelay: 1.0
|
||||
|
@ -340,9 +402,42 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
ModelScopeLogger.debug("File exists: \(file.name)")
|
||||
}
|
||||
|
||||
progress(Double(downloadedSize) / Double(totalSize))
|
||||
let currentProgress = Double(downloadedSize) / Double(totalSize)
|
||||
updateProgress(currentProgress, callback: progress)
|
||||
}
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
progress(1.0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private func resetDownloadState() async {
|
||||
totalFiles = 0
|
||||
downloadedFiles = 0
|
||||
totalSize = 0
|
||||
downloadedSize = 0
|
||||
lastUpdatedBytes = 0
|
||||
}
|
||||
|
||||
private func resetCancelStatus() {
|
||||
isCancelled = false
|
||||
|
||||
totalFiles = 0
|
||||
downloadedFiles = 0
|
||||
totalSize = 0
|
||||
downloadedSize = 0
|
||||
lastUpdatedBytes = 0
|
||||
}
|
||||
|
||||
private func closeFileHandle() async {
|
||||
do {
|
||||
try currentFileHandle?.close()
|
||||
currentFileHandle = nil
|
||||
} catch {
|
||||
print("Error closing file handle: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
private func buildURL(
|
||||
|
@ -410,4 +505,27 @@ public actor ModelScopeDownloadManager: Sendable {
|
|||
|
||||
return modelScopePath.path
|
||||
}
|
||||
|
||||
private func setCurrentFileHandle(_ handle: FileHandle?) {
|
||||
currentFileHandle = handle
|
||||
}
|
||||
|
||||
private func getTempFileSize(for file: ModelFile, destinationPath: String) -> Int64 {
|
||||
let modelHash = repoPath.hash
|
||||
let fileHash = file.path.hash
|
||||
let tempURL = FileManager.default.temporaryDirectory
|
||||
.appendingPathComponent("model_\(modelHash)_file_\(fileHash)_\(file.name.sanitizedPath).tmp")
|
||||
|
||||
guard fileManager.fileExists(atPath: tempURL.path) else {
|
||||
return 0
|
||||
}
|
||||
|
||||
do {
|
||||
let attributes = try fileManager.attributesOfItem(atPath: tempURL.path)
|
||||
return attributes[.size] as? Int64 ?? 0
|
||||
} catch {
|
||||
ModelScopeLogger.error("Failed to get temp file size for \(file.name): \(error)")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ import Foundation
|
|||
public enum ModelScopeError: Error {
|
||||
case invalidURL
|
||||
case invalidResponse
|
||||
case downloadCancelled
|
||||
case downloadFailed(Error)
|
||||
case fileSystemError(Error)
|
||||
case invalidData
|
|
@ -0,0 +1,65 @@
|
|||
//
|
||||
// CustomPopupMenu.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/6/30.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct CustomPopupMenu: View {
|
||||
@Binding var isPresented: Bool
|
||||
@Binding var selectedSource: ModelSource
|
||||
let anchorFrame: CGRect
|
||||
|
||||
var body: some View {
|
||||
GeometryReader { geometry in
|
||||
ZStack(alignment: .top) {
|
||||
|
||||
Color.black.opacity(0.3)
|
||||
.frame(maxWidth: .infinity)
|
||||
.frame(height: UIScreen.main.bounds.height - anchorFrame.maxY)
|
||||
.offset(y: anchorFrame.maxY - 10)
|
||||
.onTapGesture {
|
||||
isPresented = false
|
||||
}
|
||||
|
||||
VStack(spacing: 0) {
|
||||
ForEach(ModelSource.allCases) { source in
|
||||
Button {
|
||||
selectedSource = source
|
||||
ModelSourceManager.shared.updateSelectedSource(source)
|
||||
isPresented = false
|
||||
} label: {
|
||||
HStack {
|
||||
Text(source.description)
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(source == selectedSource ? .primaryBlue : .black)
|
||||
Spacer()
|
||||
if source == selectedSource {
|
||||
Image(systemName: "checkmark.circle")
|
||||
.foregroundColor(.primaryBlue)
|
||||
}
|
||||
}
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding()
|
||||
.background(.white)
|
||||
}
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
.background(Color.white)
|
||||
.cornerRadius(8)
|
||||
.shadow(color: .black.opacity(0.1), radius: 5, x: 0, y: 5)
|
||||
.frame(width: geometry.size.width)
|
||||
.position(
|
||||
x: geometry.size.width / 2,
|
||||
y: anchorFrame.maxY - 24
|
||||
)
|
||||
}
|
||||
}
|
||||
.transition(.opacity)
|
||||
.animation(.spring(response: 0.3, dampingFraction: 0.8, blendDuration: 0), value: isPresented)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
//
|
||||
// FilterButton.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct FilterButton: View {
|
||||
@Binding var showFilterMenu: Bool
|
||||
@Binding var selectedTags: Set<String>
|
||||
@Binding var selectedCategories: Set<String>
|
||||
@Binding var selectedVendors: Set<String>
|
||||
|
||||
var body: some View {
|
||||
Button(action: {
|
||||
showFilterMenu.toggle()
|
||||
}) {
|
||||
Image(systemName: "line.3.horizontal.decrease.circle")
|
||||
.font(.system(size: 20))
|
||||
.foregroundColor(.primary)
|
||||
}
|
||||
.sheet(isPresented: $showFilterMenu) {
|
||||
FilterMenuView(
|
||||
selectedTags: $selectedTags,
|
||||
selectedCategories: $selectedCategories,
|
||||
selectedVendors: $selectedVendors
|
||||
)
|
||||
.presentationDetents([.large])
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
//
|
||||
// FilterMenuView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct FilterMenuView: View {
|
||||
@Environment(\.dismiss) private var dismiss
|
||||
@StateObject private var viewModel = ModelListViewModel()
|
||||
@Binding var selectedTags: Set<String>
|
||||
@Binding var selectedCategories: Set<String>
|
||||
@Binding var selectedVendors: Set<String>
|
||||
|
||||
var body: some View {
|
||||
NavigationView {
|
||||
ScrollView {
|
||||
VStack(alignment: .leading, spacing: 24) {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
Text("filter.byTag")
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
|
||||
LazyVGrid(columns: Array(repeating: GridItem(.flexible()), count: 2), spacing: 8) {
|
||||
ForEach(viewModel.allTags.sorted(), id: \.self) { tag in
|
||||
FilterOptionRow(
|
||||
text: TagTranslationManager.shared.getLocalizedTag(tag),
|
||||
isSelected: selectedTags.contains(tag)
|
||||
) {
|
||||
if selectedTags.contains(tag) {
|
||||
selectedTags.remove(tag)
|
||||
} else {
|
||||
selectedTags.insert(tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Divider()
|
||||
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
Text("filter.byVendor")
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
|
||||
LazyVGrid(columns: Array(repeating: GridItem(.flexible()), count: 2), spacing: 8) {
|
||||
ForEach(viewModel.allVendors.sorted(), id: \.self) { vendor in
|
||||
FilterOptionRow(
|
||||
text: vendor,
|
||||
isSelected: selectedVendors.contains(vendor)
|
||||
) {
|
||||
if selectedVendors.contains(vendor) {
|
||||
selectedVendors.remove(vendor)
|
||||
} else {
|
||||
selectedVendors.insert(vendor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(minLength: 100)
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
.navigationTitle("filter.title")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .navigationBarLeading) {
|
||||
Button("button.clear") {
|
||||
selectedTags.removeAll()
|
||||
selectedCategories.removeAll()
|
||||
selectedVendors.removeAll()
|
||||
}
|
||||
}
|
||||
|
||||
ToolbarItem(placement: .navigationBarTrailing) {
|
||||
Button("button.done") {
|
||||
dismiss()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.onAppear {
|
||||
Task {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
//
|
||||
// FilterOptionRow.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct FilterOptionRow: View {
|
||||
let text: String
|
||||
let isSelected: Bool
|
||||
let onTap: () -> Void
|
||||
|
||||
var body: some View {
|
||||
Button(action: onTap) {
|
||||
HStack {
|
||||
Text(text)
|
||||
.font(.system(size: 14))
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Spacer()
|
||||
|
||||
if isSelected {
|
||||
Image(systemName: "checkmark.circle.fill")
|
||||
.foregroundColor(.accentColor)
|
||||
} else {
|
||||
Image(systemName: "circle")
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(isSelected ? Color.accentColor.opacity(0.1) : Color(.systemGray6))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
//
|
||||
// FilterTagChip.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct FilterTagChip: View {
|
||||
let text: String
|
||||
let isSelected: Bool
|
||||
let onTap: () -> Void
|
||||
|
||||
var body: some View {
|
||||
Button(action: onTap) {
|
||||
Text(text)
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
.foregroundColor(isSelected ? .white : .primary)
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 6)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 16)
|
||||
.fill(isSelected ? Color.primaryPurple : Color(.systemGray6))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
//
|
||||
// QuickFilterTags.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct QuickFilterTags: View {
|
||||
let tags: [String]
|
||||
@Binding var selectedTags: Set<String>
|
||||
|
||||
var body: some View {
|
||||
ScrollView(.horizontal, showsIndicators: false) {
|
||||
HStack(spacing: 8) {
|
||||
ForEach(tags, id: \.self) { tag in
|
||||
FilterTagChip(
|
||||
text: TagTranslationManager.shared.getLocalizedTag(tag),
|
||||
isSelected: selectedTags.contains(tag)
|
||||
) {
|
||||
if selectedTags.contains(tag) {
|
||||
selectedTags.remove(tag)
|
||||
} else {
|
||||
selectedTags.insert(tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
//
|
||||
// SourceSelector.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct SourceSelector: View {
|
||||
@Binding var selectedSource: ModelSource
|
||||
@Binding var showSourceMenu: Bool
|
||||
let onSourceChange: (ModelSource) -> Void
|
||||
|
||||
var body: some View {
|
||||
Menu {
|
||||
ForEach(ModelSource.allCases) { source in
|
||||
Button(action: {
|
||||
onSourceChange(source)
|
||||
}) {
|
||||
HStack {
|
||||
Text(source.rawValue)
|
||||
if source == selectedSource {
|
||||
Image(systemName: "checkmark")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} label: {
|
||||
HStack(spacing: 4) {
|
||||
Text("modelSource.title")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text(selectedSource.rawValue)
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Image(systemName: "chevron.down")
|
||||
.font(.system(size: 10))
|
||||
.foregroundColor(.primary)
|
||||
}
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 6)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
//
|
||||
// ToolbarView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct ToolbarView: View {
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
@Binding var selectedSource: ModelSource
|
||||
@Binding var showSourceMenu: Bool
|
||||
@Binding var selectedTags: Set<String>
|
||||
@Binding var selectedCategories: Set<String>
|
||||
@Binding var selectedVendors: Set<String>
|
||||
let quickFilterTags: [String]
|
||||
@Binding var showFilterMenu: Bool
|
||||
let onSourceChange: (ModelSource) -> Void
|
||||
|
||||
var body: some View {
|
||||
VStack(spacing: 12) {
|
||||
HStack {
|
||||
SourceSelector(
|
||||
selectedSource: $selectedSource,
|
||||
showSourceMenu: $showSourceMenu,
|
||||
onSourceChange: onSourceChange
|
||||
)
|
||||
|
||||
// 快捷筛选标签
|
||||
QuickFilterTags(
|
||||
tags: quickFilterTags,
|
||||
selectedTags: $selectedTags
|
||||
)
|
||||
|
||||
Spacer()
|
||||
|
||||
FilterButton(
|
||||
showFilterMenu: $showFilterMenu,
|
||||
selectedTags: $selectedTags,
|
||||
selectedCategories: $selectedCategories,
|
||||
selectedVendors: $selectedVendors
|
||||
)
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
.background(Color(.systemBackground))
|
||||
.overlay(
|
||||
Rectangle()
|
||||
.frame(height: 0.5)
|
||||
.foregroundColor(Color(.separator)),
|
||||
alignment: .bottom
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
//
|
||||
// ModelListView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct ModelListView: View {
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
@State private var searchText = ""
|
||||
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
||||
@State private var showSourceMenu = false
|
||||
@State private var selectedTags: Set<String> = []
|
||||
@State private var selectedCategories: Set<String> = []
|
||||
@State private var selectedVendors: Set<String> = []
|
||||
@State private var showFilterMenu = false
|
||||
|
||||
var body: some View {
|
||||
ScrollView {
|
||||
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||
Section {
|
||||
modelListSection
|
||||
} header: {
|
||||
toolbarSection
|
||||
}
|
||||
}
|
||||
}
|
||||
.searchable(text: $searchText, prompt: "Search models...")
|
||||
.onChange(of: searchText) { _, newValue in
|
||||
viewModel.searchText = newValue
|
||||
}
|
||||
.refreshable {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK") { }
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract model list section as independent view
|
||||
@ViewBuilder
|
||||
private var modelListSection: some View {
|
||||
LazyVStack(spacing: 8) {
|
||||
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
|
||||
modelRowView(model: model, index: index)
|
||||
|
||||
if index < filteredModels.count - 1 {
|
||||
Divider()
|
||||
.padding(.leading, 60)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
|
||||
// Extract toolbar section as independent view
|
||||
@ViewBuilder
|
||||
private var toolbarSection: some View {
|
||||
ToolbarView(
|
||||
viewModel: viewModel, selectedSource: $selectedSource,
|
||||
showSourceMenu: $showSourceMenu,
|
||||
selectedTags: $selectedTags,
|
||||
selectedCategories: $selectedCategories,
|
||||
selectedVendors: $selectedVendors,
|
||||
quickFilterTags: viewModel.quickFilterTags,
|
||||
showFilterMenu: $showFilterMenu,
|
||||
onSourceChange: handleSourceChange
|
||||
)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func modelRowView(model: ModelInfo, index: Int) -> some View {
|
||||
ModelRowView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
|
||||
isDownloading: viewModel.currentlyDownloading == model.id,
|
||||
isOtherDownloading: isOtherDownloadingCheck(model: model)
|
||||
) {
|
||||
Task {
|
||||
await viewModel.downloadModel(model)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
|
||||
// Extract complex boolean logic as independent method
|
||||
private func isOtherDownloadingCheck(model: ModelInfo) -> Bool {
|
||||
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
|
||||
}
|
||||
|
||||
// Extract source change handling logic as independent method
|
||||
private func handleSourceChange(_ source: ModelSource) {
|
||||
ModelSourceManager.shared.updateSelectedSource(source)
|
||||
selectedSource = source
|
||||
Task {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
// Filter models based on selected tags, categories and vendors
|
||||
private var filteredModels: [ModelInfo] {
|
||||
let baseFiltered = viewModel.filteredModels
|
||||
|
||||
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
|
||||
return baseFiltered
|
||||
}
|
||||
|
||||
return baseFiltered.filter { model in
|
||||
let tagMatch = checkTagMatch(model: model)
|
||||
let categoryMatch = checkCategoryMatch(model: model)
|
||||
let vendorMatch = checkVendorMatch(model: model)
|
||||
|
||||
return tagMatch && categoryMatch && vendorMatch
|
||||
}
|
||||
}
|
||||
|
||||
// Extract tag matching logic as independent method
|
||||
private func checkTagMatch(model: ModelInfo) -> Bool {
|
||||
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
|
||||
model.localizedTags.contains { tag in
|
||||
tag.localizedCaseInsensitiveContains(selectedTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract category matching logic as independent method
|
||||
private func checkCategoryMatch(model: ModelInfo) -> Bool {
|
||||
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
|
||||
model.categories?.contains { category in
|
||||
category.localizedCaseInsensitiveContains(selectedCategory)
|
||||
} ?? false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract vendor matching logic as independent method
|
||||
private func checkVendorMatch(model: ModelInfo) -> Bool {
|
||||
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
|
||||
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
//
|
||||
// ActionButtonsView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - 操作按钮视图
|
||||
struct ActionButtonsView: View {
|
||||
let model: ModelInfo
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
let downloadProgress: Double
|
||||
let isDownloading: Bool
|
||||
let isOtherDownloading: Bool
|
||||
let formattedSize: String
|
||||
let onDownload: () -> Void
|
||||
@Binding var showDeleteAlert: Bool
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .center, spacing: 4) {
|
||||
if model.isDownloaded {
|
||||
// 已下载状态
|
||||
DownloadedButtonView(showDeleteAlert: $showDeleteAlert)
|
||||
} else if isDownloading {
|
||||
// 下载中状态
|
||||
DownloadingButtonView(
|
||||
viewModel: viewModel,
|
||||
downloadProgress: downloadProgress
|
||||
)
|
||||
} else {
|
||||
// 待下载状态
|
||||
PendingDownloadButtonView(
|
||||
isOtherDownloading: isOtherDownloading,
|
||||
formattedSize: formattedSize,
|
||||
onDownload: onDownload
|
||||
)
|
||||
}
|
||||
}
|
||||
.frame(width: 60)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
//
|
||||
// DownloadedButtonView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - 已下载按钮视图
|
||||
struct DownloadedButtonView: View {
|
||||
@Binding var showDeleteAlert: Bool
|
||||
|
||||
var body: some View {
|
||||
Button(action: { showDeleteAlert = true }) {
|
||||
VStack(spacing: 2) {
|
||||
Image(systemName: "trash")
|
||||
.font(.system(size: 16))
|
||||
.foregroundColor(.primary.opacity(0.8))
|
||||
|
||||
Text(LocalizedStringKey("button.downloaded"))
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
.lineLimit(1)
|
||||
.minimumScaleFactor(0.8)
|
||||
.allowsTightening(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
//
|
||||
// DownloadingButtonView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - 下载中按钮视图
|
||||
struct DownloadingButtonView: View {
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
let downloadProgress: Double
|
||||
|
||||
var body: some View {
|
||||
Button(action: {
|
||||
Task {
|
||||
await viewModel.cancelDownload()
|
||||
}
|
||||
}) {
|
||||
VStack(spacing: 2) {
|
||||
ProgressView(value: downloadProgress)
|
||||
.progressViewStyle(CircularProgressViewStyle(tint: .accentColor))
|
||||
.frame(width: 24, height: 24)
|
||||
|
||||
Text(String(format: "%.2f%%", downloadProgress * 100))
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
//
|
||||
// PendingDownloadButtonView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct PendingDownloadButtonView: View {
|
||||
let isOtherDownloading: Bool
|
||||
let formattedSize: String
|
||||
let onDownload: () -> Void
|
||||
|
||||
var body: some View {
|
||||
Button(action: onDownload) {
|
||||
Image(systemName: "arrow.down.circle.fill")
|
||||
.font(.title2)
|
||||
.foregroundColor(isOtherDownloading ? .secondary : .primaryPurple)
|
||||
}
|
||||
.disabled(isOtherDownloading)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
//
|
||||
// TagChip.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - 标签芯片
|
||||
struct TagChip: View {
|
||||
let text: String
|
||||
|
||||
var body: some View {
|
||||
Text(TagTranslationManager.shared.getLocalizedTag(text))
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 3)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.secondary.opacity(0.3), lineWidth: 0.5)
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
//
|
||||
// TagsView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - 标签视图
|
||||
struct TagsView: View {
|
||||
let tags: [String]
|
||||
|
||||
var body: some View {
|
||||
ScrollView(.horizontal, showsIndicators: false) {
|
||||
HStack(spacing: 6) {
|
||||
ForEach(tags, id: \.self) { tag in
|
||||
TagChip(text: tag)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 1)
|
||||
}
|
||||
.frame(height: 25)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
//
|
||||
// ModelRowView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct ModelRowView: View {
|
||||
|
||||
let model: ModelInfo
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
|
||||
let downloadProgress: Double
|
||||
let isDownloading: Bool
|
||||
let isOtherDownloading: Bool
|
||||
let onDownload: () -> Void
|
||||
|
||||
@State private var showDeleteAlert = false
|
||||
|
||||
private var localizedTags: [String] {
|
||||
model.localizedTags
|
||||
}
|
||||
|
||||
private var formattedSize: String {
|
||||
model.formattedSize
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .center, spacing: 0) {
|
||||
|
||||
ModelIconView(modelId: model.id)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
|
||||
Text(model.modelName)
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
.lineLimit(1)
|
||||
|
||||
if !localizedTags.isEmpty {
|
||||
TagsView(tags: localizedTags)
|
||||
}
|
||||
|
||||
HStack(alignment: .center, spacing: 2) {
|
||||
Image(systemName: "folder")
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
.frame(width: 20, height: 20)
|
||||
|
||||
Text(formattedSize)
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
|
||||
Spacer()
|
||||
|
||||
VStack {
|
||||
Spacer()
|
||||
ActionButtonsView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: downloadProgress,
|
||||
isDownloading: isDownloading,
|
||||
isOtherDownloading: isOtherDownloading,
|
||||
formattedSize: formattedSize,
|
||||
onDownload: onDownload,
|
||||
showDeleteAlert: $showDeleteAlert
|
||||
)
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
.contentShape(Rectangle())
|
||||
.onTapGesture {
|
||||
handleRowTap()
|
||||
}
|
||||
.alert(LocalizedStringKey("alert.deleteModel.title"), isPresented: $showDeleteAlert) {
|
||||
Button("Delete", role: .destructive) {
|
||||
Task {
|
||||
await viewModel.deleteModel(model)
|
||||
}
|
||||
}
|
||||
Button("Cancel", role: .cancel) { }
|
||||
} message: {
|
||||
Text(LocalizedStringKey("alert.deleteModel.message"))
|
||||
}
|
||||
}
|
||||
|
||||
private func handleRowTap() {
|
||||
if model.isDownloaded {
|
||||
return
|
||||
} else if isDownloading {
|
||||
Task {
|
||||
await viewModel.cancelDownload()
|
||||
}
|
||||
} else if !isOtherDownloading {
|
||||
onDownload()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,8 +14,10 @@ struct SearchBar: View {
|
|||
HStack {
|
||||
Image(systemName: "magnifyingglass")
|
||||
.foregroundColor(.gray)
|
||||
.padding(.horizontal, 10)
|
||||
|
||||
TextField("Search models...", text: $text)
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.textFieldStyle(RoundedBorderTextFieldStyle())
|
||||
.autocapitalization(.none)
|
||||
.disableAutocorrection(true)
|
|
@ -0,0 +1,40 @@
|
|||
//
|
||||
// SwipeActionsView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct SwipeActionsView: View {
|
||||
let model: ModelInfo
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
|
||||
var body: some View {
|
||||
if viewModel.pinnedModelIds.contains(model.id) {
|
||||
Button {
|
||||
viewModel.unpinModel(model)
|
||||
} label: {
|
||||
Label(LocalizedStringKey("button.unpin"), systemImage: "pin.slash")
|
||||
}.tint(.gray)
|
||||
} else {
|
||||
Button {
|
||||
viewModel.pinModel(model)
|
||||
} label: {
|
||||
Label(LocalizedStringKey("button.pin"), systemImage: "pin")
|
||||
}.tint(.primaryBlue)
|
||||
}
|
||||
if model.isDownloaded {
|
||||
Button(role: .destructive) {
|
||||
Task {
|
||||
await viewModel.deleteModel(model)
|
||||
}
|
||||
} label: {
|
||||
Label("Delete", systemImage: "trash")
|
||||
}
|
||||
.tint(.primaryRed)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
//
|
||||
// ModelDownloadStorage.swift
|
||||
// ModelSource.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/2/20.
|
||||
|
@ -7,11 +7,13 @@
|
|||
|
||||
import Foundation
|
||||
|
||||
public enum ModelSource: String, CaseIterable {
|
||||
public enum ModelSource: String, CaseIterable, Identifiable {
|
||||
case modelScope = "ModelScope"
|
||||
case huggingFace = "Hugging Face"
|
||||
case huggingFace = "HuggingFace"
|
||||
case modeler = "Modeler"
|
||||
|
||||
public var id: Self { self }
|
||||
|
||||
var description: String {
|
||||
switch self {
|
||||
case .modelScope:
|
|
@ -1,5 +1,5 @@
|
|||
//
|
||||
// ModelDownloadStorage.swift
|
||||
// ModelSourceManager.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/2/20.
|
|
@ -0,0 +1,101 @@
|
|||
//
|
||||
// ModelStorageManager.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/10.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
class ModelStorageManager {
|
||||
static let shared = ModelStorageManager()
|
||||
|
||||
private let userDefaults = UserDefaults.standard
|
||||
private let downloadedModelsKey = "com.mnnllm.downloadedModels"
|
||||
private let lastUsedModelKey = "com.mnnllm.lastUsedModels"
|
||||
private let cachedSizesKey = "com.mnnllm.cachedSizes"
|
||||
|
||||
private init() {}
|
||||
|
||||
var lastUsedModels: [String: Date] {
|
||||
get {
|
||||
userDefaults.dictionary(forKey: lastUsedModelKey) as? [String: Date] ?? [:]
|
||||
}
|
||||
set {
|
||||
userDefaults.set(newValue, forKey: lastUsedModelKey)
|
||||
}
|
||||
}
|
||||
|
||||
func updateLastUsed(for modelName: String) {
|
||||
var models = lastUsedModels
|
||||
models[modelName] = Date()
|
||||
lastUsedModels = models
|
||||
}
|
||||
|
||||
func getLastUsed(for modelName: String) -> Date? {
|
||||
return lastUsedModels[modelName]
|
||||
}
|
||||
|
||||
var downloadedModels: [String] {
|
||||
get {
|
||||
userDefaults.array(forKey: downloadedModelsKey) as? [String] ?? []
|
||||
}
|
||||
set {
|
||||
userDefaults.set(newValue, forKey: downloadedModelsKey)
|
||||
}
|
||||
}
|
||||
|
||||
func clearDownloadStatus(for modelName: String) {
|
||||
var models = downloadedModels
|
||||
models.removeAll { $0 == modelName }
|
||||
downloadedModels = models
|
||||
}
|
||||
|
||||
func isModelDownloaded(_ modelName: String) -> Bool {
|
||||
downloadedModels.contains(modelName)
|
||||
}
|
||||
|
||||
func markModelAsDownloaded(_ modelName: String) {
|
||||
var models = downloadedModels
|
||||
if !models.contains(modelName) {
|
||||
models.append(modelName)
|
||||
downloadedModels = models
|
||||
}
|
||||
}
|
||||
|
||||
func markModelAsNotDownloaded(_ modelName: String) {
|
||||
var models = downloadedModels
|
||||
models.removeAll { $0 == modelName }
|
||||
downloadedModels = models
|
||||
|
||||
// Also clear cached size when model is marked as not downloaded
|
||||
clearCachedSize(for: modelName)
|
||||
}
|
||||
|
||||
// MARK: - Cached Size Management
|
||||
|
||||
var cachedSizes: [String: Int64] {
|
||||
get {
|
||||
userDefaults.dictionary(forKey: cachedSizesKey) as? [String: Int64] ?? [:]
|
||||
}
|
||||
set {
|
||||
userDefaults.set(newValue, forKey: cachedSizesKey)
|
||||
}
|
||||
}
|
||||
|
||||
func getCachedSize(for modelName: String) -> Int64? {
|
||||
return cachedSizes[modelName]
|
||||
}
|
||||
|
||||
func setCachedSize(_ size: Int64, for modelName: String) {
|
||||
var sizes = cachedSizes
|
||||
sizes[modelName] = size
|
||||
cachedSizes = sizes
|
||||
}
|
||||
|
||||
func clearCachedSize(for modelName: String) {
|
||||
var sizes = cachedSizes
|
||||
sizes.removeValue(forKey: modelName)
|
||||
cachedSizes = sizes
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
//
|
||||
// TagTranslationManager.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
class TagTranslationManager {
|
||||
static let shared = TagTranslationManager()
|
||||
private var tagTranslations: [String: String] = [:]
|
||||
|
||||
private init() {}
|
||||
|
||||
func loadTagTranslations(_ translations: [String: String]) {
|
||||
tagTranslations = translations
|
||||
}
|
||||
|
||||
func getLocalizedTag(_ tag: String) -> String {
|
||||
let currentLanguage = LanguageManager.shared.currentLanguage
|
||||
let isChineseLanguage = currentLanguage == "zh-Hans"
|
||||
|
||||
if isChineseLanguage, let translation = tagTranslations[tag] {
|
||||
return translation
|
||||
}
|
||||
return tag
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
//
|
||||
// ModelClient.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
struct ModelInfo: Codable {
|
||||
let modelId: String
|
||||
let createdAt: String
|
||||
let downloads: Int
|
||||
let tags: [String]
|
||||
|
||||
var name: String {
|
||||
modelId.removingTaobaoPrefix()
|
||||
}
|
||||
|
||||
var isDownloaded: Bool = false
|
||||
|
||||
var localPath: String {
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case modelId
|
||||
case tags
|
||||
case downloads
|
||||
case createdAt
|
||||
}
|
||||
}
|
||||
|
||||
struct RepoInfo: Codable {
|
||||
let modelId: String
|
||||
let sha: String
|
||||
let siblings: [Sibling]
|
||||
|
||||
struct Sibling: Codable {
|
||||
let rfilename: String
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue