mirror of https://github.com/alibaba/MNN.git

Compare commits: 638a091d2c ... 310eaf9250 (24 commits)

310eaf9250
5952e33570
a97857ef4c
136cfe74b7
b373b8c297
01693ab81c
d588ce0d51
95dde21a1e
7d9ea467bc
7a92132554
dea02b4b8a
4fea0b4c76
ba603df644
27442d0ccf
d1913d8111
380370fcf8
36f087d946
90dba5548e
8a0f12591f
1233a8f87f
20e62dde9b
2d589fc032
8d11c13b44
c6336fa6b0
@@ -77,11 +77,6 @@
			path = MNNLLMiOS;
			sourceTree = "<group>";
		};
		3EAD33932D2BF38800021CF7 /* LocalModel */ = {
			isa = PBXFileSystemSynchronizedRootGroup;
			path = LocalModel;
			sourceTree = "<group>";
		};
/* End PBXFileSystemSynchronizedRootGroup section */

/* Begin PBXFrameworksBuildPhase section */

@@ -117,7 +112,6 @@
			isa = PBXGroup;
			children = (
				3E94DF1F2D30CCF900BE39A7 /* MNNLLMiOS */,
				3EAD33932D2BF38800021CF7 /* LocalModel */,
				3E8592122D1D45090067B46F /* MNNLLMiOSTests */,
				3E85921C2D1D45090067B46F /* MNNLLMiOSUITests */,
				3EAD33842D2BDDC300021CF7 /* Frameworks */,

@@ -160,7 +154,6 @@
			);
			fileSystemSynchronizedGroups = (
				3E94DF1F2D30CCF900BE39A7 /* MNNLLMiOS */,
				3EAD33932D2BF38800021CF7 /* LocalModel */,
			);
			name = MNNLLMiOS;
			packageProductDependencies = (
@@ -1,214 +0,0 @@
# UI Performance Optimization Guide

## Overview

This document explains the performance optimization utilities implemented in the chat system to ensure smooth AI streaming text output and overall UI responsiveness.

## Core Components

### 1. PerformanceMonitor

A singleton utility for real-time performance monitoring and measurement.

#### Implementation Principles

- **Real-time FPS Monitoring**: Tracks UI update frequency and calculates the actual FPS
- **Frame Drop Detection**: Identifies when the gap between UI updates exceeds the 60 FPS frame budget (16.67 ms; a 25 ms threshold is used in practice)
- **Operation Time Measurement**: Measures the execution time of specific operations
- **Automatic Statistics Reporting**: Logs performance metrics every second

#### Key Features

```swift
class PerformanceMonitor {
    static let shared = PerformanceMonitor()

    // Performance thresholds
    private let targetFPS: Double = 60.0
    private let frameThreshold: TimeInterval = 1.0 / 60.0 * 1.5 // ≈25 ms threshold

    func recordUIUpdate() { /* Track UI updates */ }
    func measureExecutionTime<T>(operation: String, block: () throws -> T) { /* Measure operations */ }
}
```

#### Usage in the Chat System

- Integrated into `LLMChatInteractor` to monitor message updates
- Tracks UI update frequency during AI text streaming
- Identifies performance bottlenecks in real time

### 2. UIUpdateOptimizer

An actor-based utility for batching and throttling UI updates during streaming scenarios.

#### Implementation Principles

- **Batching Mechanism**: Groups multiple small updates into larger, more efficient ones
- **Time-based Throttling**: Limits update frequency to prevent UI overload
- **Actor-based Thread Safety**: Ensures safe concurrent access to the update queue
- **Automatic Flush Strategy**: Intelligently decides when to apply batched updates

#### Architecture

```swift
actor UIUpdateOptimizer {
    static let shared = UIUpdateOptimizer()

    private var pendingUpdates: [String] = []
    private let batchSize: Int = 5 // Batch threshold
    private let flushInterval: TimeInterval = 0.03 // 30 ms throttling

    func addUpdate(_ content: String, completion: @escaping (String) -> Void)
    func forceFlush(completion: @escaping (String) -> Void)
}
```

#### Optimization Strategies

1. **Batch Size Control**: Groups up to 5 updates before flushing
2. **Time-based Throttling**: Flushes updates every 30 ms at most
3. **Intelligent Scheduling**: Cancels redundant flush operations
4. **Main Thread Delegation**: Ensures UI updates occur on the main thread

#### Integration Points

- **LLM Streaming**: Optimizes real-time text output from AI models
- **Message Updates**: Batches frequent message content changes
- **Force Flush**: Ensures the final content is displayed when streaming ends

## Performance Optimization Flow

```
AI Model Output → UIUpdateOptimizer → Batched Updates → UI Thread → Display
                         ↓
            PerformanceMonitor (Monitoring)
                         ↓
              Console Logs (Metrics)
```

## Testing and Validation

### Performance Metrics

1. **Target Performance**:
   - Maintain 50+ FPS during streaming
   - Keep the frame drop rate below 5%
   - Single operations under 16 ms

2. **Monitoring Indicators**:
   - `📊 Performance Stats` – real-time FPS and drop rate
   - `⚠️ UI Update Lag detected` – frame drop warnings
   - `⏱️ Slow Operation` – operation time alerts

### Testing Methodology

1. **Streaming Tests**:
   - Test with long-form AI responses (articles, code)
   - Monitor console output for performance warnings
   - Observe the visual smoothness of the text animation

2. **Load Testing**:
   - Rapid successive message sending
   - Processing of large text blocks
   - Multiple concurrent operations

3. **Comparative Analysis**:
   - Before/after optimization measurements
   - Different device performance profiles
   - Various content types and sizes

### Debug Configuration

For development and testing purposes:

```swift
// Example configuration adjustments (not implemented in production)
// UIUpdateOptimizer.shared.batchSize = 10
// UIUpdateOptimizer.shared.flushInterval = 0.05
```

## Implementation Details

### UIUpdateOptimizer Algorithm

1. **Add Update**: New content is appended to the pending queue
2. **Threshold Check**: Evaluate whether an immediate flush is needed
   - Batch size reached (≥5 updates)
   - Time threshold exceeded (≥30 ms since the last flush)
3. **Scheduling**: If not immediate, schedule a delayed flush
4. **Flush Execution**: Combine all pending updates and execute on the main thread
5. **Cleanup**: Clear the queue and reset timing (see the sketch below)
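
The steps above can be made concrete. The following is a minimal Swift sketch of the batching actor, assuming the `batchSize`/`flushInterval` constants and the `pendingUpdates` queue named earlier in this guide; the task-based delayed flush and the delivery mechanics are illustrative, not the production implementation:

```swift
actor UIUpdateOptimizer {
    static let shared = UIUpdateOptimizer()
    private init() {}

    private var pendingUpdates: [String] = []
    private var lastFlushTime = Date.distantPast
    private var flushTask: Task<Void, Never>?

    private let batchSize = 5                      // flush once 5 updates are queued
    private let flushInterval: TimeInterval = 0.03 // or once 30 ms have passed

    func addUpdate(_ content: String, completion: @escaping (String) -> Void) {
        pendingUpdates.append(content)

        let dueToSize = pendingUpdates.count >= batchSize
        let dueToTime = Date().timeIntervalSince(lastFlushTime) >= flushInterval
        if dueToSize || dueToTime {
            flush(completion: completion)
        } else {
            // Schedule a delayed flush; cancel any previously scheduled one.
            flushTask?.cancel()
            flushTask = Task {
                try? await Task.sleep(nanoseconds: UInt64(flushInterval * 1_000_000_000))
                guard !Task.isCancelled else { return }
                flush(completion: completion)
            }
        }
    }

    func forceFlush(completion: @escaping (String) -> Void) {
        flush(completion: completion)
    }

    private func flush(completion: @escaping (String) -> Void) {
        guard !pendingUpdates.isEmpty else { return }
        let combined = pendingUpdates.joined() // combine all pending fragments
        pendingUpdates.removeAll()             // cleanup: clear queue, reset timing
        lastFlushTime = Date()
        Task { @MainActor in completion(combined) } // deliver on the main thread
    }
}
```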

### PerformanceMonitor Algorithm

1. **Update Recording**: Track each UI update call
2. **Timing Analysis**: Calculate the time difference between consecutive updates
3. **Frame Drop Detection**: Compare against the 25 ms threshold
4. **Statistics Calculation**: Compute FPS and drop rate every second
5. **Logging**: Output performance metrics to the console (sketched below)
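
A hedged sketch of this monitoring loop follows, assuming a one-second statistics window, the 25 ms frame threshold quoted above, and that `recordUIUpdate()` is called on the main thread; the counter names, the `rethrows` completion of `measureExecutionTime`, and the log formats are illustrative:

```swift
import QuartzCore

final class PerformanceMonitor {
    static let shared = PerformanceMonitor()
    private init() {}

    private let frameThreshold: CFTimeInterval = 1.0 / 60.0 * 1.5 // ≈25 ms
    private var lastUpdateTime: CFTimeInterval = 0
    private var updateCount = 0
    private var droppedFrames = 0
    private var windowStart: CFTimeInterval = CACurrentMediaTime()

    // Assumed to be called on the main thread for every tracked UI update.
    func recordUIUpdate() {
        let now = CACurrentMediaTime()

        // Frame-drop detection: gap between consecutive updates vs. the 25 ms threshold.
        if lastUpdateTime > 0, now - lastUpdateTime > frameThreshold {
            droppedFrames += 1
        }
        lastUpdateTime = now
        updateCount += 1

        // Statistics window: report FPS and drop rate once per second.
        let elapsed = now - windowStart
        if elapsed >= 1.0 {
            let fps = Double(updateCount) / elapsed
            let dropRate = Double(droppedFrames) / Double(max(updateCount, 1)) * 100
            print("📊 Performance Stats: \(String(format: "%.1f", fps)) FPS, \(String(format: "%.1f", dropRate))% dropped")
            updateCount = 0
            droppedFrames = 0
            windowStart = now
        }
    }

    func measureExecutionTime<T>(operation: String, block: () throws -> T) rethrows -> T {
        let start = CACurrentMediaTime()
        defer {
            let ms = (CACurrentMediaTime() - start) * 1000
            if ms > 16 { print("⏱️ Slow Operation: \(operation) took \(String(format: "%.1f", ms)) ms") }
        }
        return try block()
    }
}
```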

## Integration Examples

### In ViewModels

```swift
func updateUI() {
    PerformanceMonitor.shared.recordUIUpdate()
    // UI update code here
}

let result = PerformanceMonitor.shared.measureExecutionTime(operation: "Data Processing") {
    return processLargeDataSet()
}
```

### In Streaming Scenarios

```swift
await UIUpdateOptimizer.shared.addUpdate(newText) { batchedContent in
    // Update the UI with the optimized batched content
    updateTextView(with: batchedContent)
}

// When the stream ends
await UIUpdateOptimizer.shared.forceFlush { finalContent in
    finalizeTextDisplay(with: finalContent)
}
```

## Troubleshooting

### Common Performance Issues

1. **High Frame Drop Rate**:
   - Check for blocking operations on the main thread
   - Verify the batch size configuration
   - Monitor memory usage

2. **Slow Operation Warnings**:
   - Profile the specific operations causing delays
   - Consider background threading for heavy tasks
   - Optimize data-processing algorithms

3. **Inconsistent Performance**:
   - Check the device thermal state
   - Monitor memory pressure
   - Verify background app activity

### Diagnostic Tools

- **Console Monitoring**: Watch for performance log messages
- **Xcode Instruments**: Use Time Profiler for detailed analysis
- **Memory Graph**: Check for memory leaks affecting performance
- **Energy Impact**: Monitor battery and thermal effects

## Best Practices

1. **Proactive Monitoring**: Always call `recordUIUpdate()` for critical UI operations
2. **Batch When Possible**: Use `UIUpdateOptimizer` for frequent updates
3. **Measure Critical Paths**: Wrap expensive operations with `measureExecutionTime`
4. **Test on Real Devices**: Performance varies significantly across device types
5. **Monitor in Production**: Keep performance logging enabled during development

This performance optimization system ensures a smooth user experience during AI text generation while giving developers the tools they need to maintain and improve performance over time.
@@ -74,6 +74,11 @@ final class LLMChatViewModel: ObservableObject {

    deinit {
        print("yxy:: LLMChat View Model deinit")

        llm?.cancelInference()
        llm = nil
        diffusion = nil
        print("yxy:: LLMChat View Model cleanup complete")
    }

    func setupLLM(modelPath: String) {

@@ -148,6 +153,9 @@ final class LLMChatViewModel: ObservableObject {
    }

    func sendToLLM(draft: DraftMessage) {

        NotificationCenter.default.post(name: .dismissKeyboard, object: nil)

        self.send(draft: draft, userType: .user)
        if isModelLoaded {
            if modelInfo.modelName.lowercased().contains("diffusion") {

@@ -258,6 +266,10 @@ final class LLMChatViewModel: ObservableObject {
        await MainActor.run {
            self.isProcessing = false
            self.currentStreamingMessageId = nil

            DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
                NotificationCenter.default.post(name: .dismissKeyboard, object: nil)
            }
        }
        await self.llmState.setProcessing(false)
    }

@@ -342,6 +354,9 @@ final class LLMChatViewModel: ObservableObject {
        )

        interactor.disconnect()

        llm?.cancelInference()

        llm = nil

        FileOperationManager.shared.cleanTempDirectories()
@@ -83,7 +83,9 @@ struct LLMChatView: View {
        .disabled(viewModel.chatInputUnavilable)
        .toolbar {
            ToolbarItem(placement: .navigationBarLeading) {
                Button { presentationMode.wrappedValue.dismiss() } label: {
                Button {
                    presentationMode.wrappedValue.dismiss()
                } label: {
                    Image("backArrow", bundle: .current)
                }
            }

@@ -119,6 +121,10 @@ struct LLMChatView: View {
            viewModel.onStart()
        }
        .onDisappear(perform: viewModel.onStop)
        .onReceive(NotificationCenter.default.publisher(for: .dismissKeyboard)) { _ in
            // Hide the keyboard
            UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
        }
    }

    // MARK: - LLM Chat Message Builder
@@ -63,19 +63,33 @@ class ChatHistoryDatabase {
        for msg in message.attachments {
            if msg.type == .image {
                var imageUrl = msg.full

                print("Processing image attachment: \(imageUrl.path)")

                guard let copiedImage = ChatHistoryFileManager.shared.copyFile(from: imageUrl, for: historyId) else {
                    print("Failed to copy image file: \(imageUrl.path)")
                    continue
                }

                imageUrl = copiedImage
                print("Image copied to: \(imageUrl.path)")

                if imageUrl.isHEICImage() {
                    guard let jpgUrl = AssetExtractor.convertHEICToJPG(heicUrl: imageUrl) else { continue }
                    guard let jpgUrl = AssetExtractor.convertHEICToJPG(heicUrl: imageUrl) else {
                        print("Failed to convert HEIC to JPG: \(imageUrl.path)")
                        continue
                    }
                    imageUrl = jpgUrl
                    print("HEIC converted to JPG: \(imageUrl.path)")
                }

                copiedImages.append(LLMChatImage.init(id: msg.id, thumbnail: imageUrl, full: imageUrl))
                // Verify that the final file exists
                if ChatHistoryFileManager.shared.validateFileExists(at: imageUrl) {
                    copiedImages.append(LLMChatImage.init(id: msg.id, thumbnail: imageUrl, full: imageUrl))
                    print("Image successfully saved for history: \(imageUrl.path)")
                } else {
                    print("Final image file validation failed: \(imageUrl.path)")
                }
            }
        }

@@ -124,7 +138,10 @@ class ChatHistoryDatabase {
        do {
            for history in try db.prepare(chatHistories) {
                let messagesData = history[messages].data(using: .utf8)!
                let historyMessages = try JSONDecoder().decode([HistoryMessage].self, from: messagesData)
                var historyMessages = try JSONDecoder().decode([HistoryMessage].self, from: messagesData)

                // Validate and fix image paths
                historyMessages = validateAndFixImagePaths(historyMessages, historyId: history[id])

                let chatHistory = ChatHistory(
                    id: history[id],

@@ -144,6 +161,38 @@ class ChatHistoryDatabase {
        return histories
    }

    private func validateAndFixImagePaths(_ messages: [HistoryMessage], historyId: String) -> [HistoryMessage] {
        return messages.map { message in
            var updatedMessage = message

            if let images = message.images {
                let validImages = images.compactMap { image -> LLMChatImage? in
                    if ChatHistoryFileManager.shared.validateFileExists(at: image.full) {
                        return image
                    } else {
                        let fileName = image.full.lastPathComponent
                        if let validURL = ChatHistoryFileManager.shared.getValidFileURL(for: fileName, historyId: historyId) {
                            return LLMChatImage(id: image.id, thumbnail: validURL, full: validURL)
                        } else {
                            print("Image file not found and cannot be recovered: \(image.full.path)")
                            return nil
                        }
                    }
                }
                updatedMessage = HistoryMessage(
                    id: message.id,
                    content: message.content,
                    images: validImages.isEmpty ? nil : validImages,
                    audio: message.audio,
                    isUser: message.isUser,
                    createdAt: message.createdAt
                )
            }

            return updatedMessage
        }
    }

    func deleteHistory(_ history: ChatHistory) {
        do {
            try db.run(chatHistories.filter(id == history.id).delete())

@@ -152,4 +201,4 @@ class ChatHistoryDatabase {
                print("Failed to delete history: \(error)")
            }
        }
    }
}
@@ -47,10 +47,16 @@ class ChatHistoryFileManager {

        // Check if the file already exists at the destination
        if FileManager.default.fileExists(atPath: destinationURL.path) {
            print("File already exists at \(destinationURL), returning original URL.")
            print("File already exists at \(destinationURL), returning existing URL.")
            return destinationURL
        }

        // Check that the source file exists before copying
        guard FileManager.default.fileExists(atPath: url.path) else {
            print("Source file does not exist at \(url.path)")
            return nil
        }

        do {
            try FileManager.default.copyItem(at: url, to: destinationURL)
            print("File copied to \(destinationURL)")

@@ -61,6 +67,24 @@ class ChatHistoryFileManager {
            return nil
        }

    // Validate whether a file exists at the given URL
    func validateFileExists(at url: URL) -> Bool {
        return FileManager.default.fileExists(atPath: url.path)
    }

    // Get the correct file URL for a history file, checking that it exists
    func getValidFileURL(for fileName: String, historyId: String) -> URL? {
        let historyDirectory = baseDirectory.appendingPathComponent(historyId)
        let fileURL = historyDirectory.appendingPathComponent(fileName)

        if FileManager.default.fileExists(atPath: fileURL.path) {
            return fileURL
        }

        print("File not found at expected path: \(fileURL.path)")
        return nil
    }

    // Delete the history directory
    func deleteHistoryDirectory(for historyId: String) {
        let historyDirectory = baseDirectory.appendingPathComponent(historyId)
@@ -13,8 +13,8 @@ struct ChatHistoryItemView: View {
    var body: some View {
        VStack(alignment: .leading, spacing: 8) {

            if let firstMessage = history.messages.last {
                Text(String(firstMessage.content.prefix(200)))
            if let lastMessage = getLastNonEmptyMessage() {
                Text(String(lastMessage.content.prefix(200)))
                .lineLimit(1)
                .font(.system(size: 15, weight: .medium))
                .foregroundColor(.primary)

@@ -42,4 +42,13 @@ struct ChatHistoryItemView: View {
        .padding(.vertical, 10)
        .padding(.horizontal, 0)
    }

    private func getLastNonEmptyMessage() -> HistoryMessage? {
        for message in history.messages.reversed() {
            if !message.content.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
                return message
            }
        }
        return nil
    }
}
@@ -70,14 +70,79 @@
#include <mutex>
#include <thread>
#include <chrono>
#include <MNN/llm/llm.hpp>
#include <fstream>
#include <iomanip>
#include <vector>
#include <utility>
#include <sys/stat.h>
#include <unistd.h>
#include <errno.h>
#include "MNN/expr/ExecutorScope.hpp"

#import <Foundation/Foundation.h>
#import "LLMInferenceEngineWrapper.h"

using namespace MNN::Transformer;
// Conditional include for MNN headers
#ifdef __has_include
#if __has_include(<MNN/llm/llm.hpp>)
#include <MNN/llm/llm.hpp>
using namespace MNN::Transformer;
#else
// Fallback declarations when MNN headers are not available
namespace MNN {
namespace Transformer {
class Llm {
public:
    static Llm* createLLM(const std::string& config_path);
    virtual void set_config(const std::string& config) = 0;
    virtual void load() = 0;
    virtual void response(const std::string& input_str, std::ostream* os = nullptr, const char* end_with = nullptr) = 0;
    virtual void response(const std::vector<std::pair<std::string, std::string>>& history, std::ostream* os = nullptr, const char* end_with = nullptr, int max_new_tokens = 999999) = 0;
    virtual void response(const std::vector<int>& tokens, std::ostream* os = nullptr, const char* end_with = nullptr, int max_new_tokens = 999999) = 0;
    virtual void reset() = 0;
    virtual bool stoped() = 0; // spelling matches MNN's llm.hpp and the call sites below
    virtual int generate(int max_token_number = 0) = 0;
    struct LlmContext {
        int prompt_len;
        int gen_seq_len;
        int64_t prefill_us;
        int64_t decode_us;
    };
    virtual LlmContext* getContext() = 0;
    virtual ~Llm() = default;
};
}
}
using namespace MNN::Transformer;
#endif
#else
// Fallback for older compilers
namespace MNN {
namespace Transformer {
class Llm {
public:
    static Llm* createLLM(const std::string& config_path);
    virtual void set_config(const std::string& config) = 0;
    virtual void load() = 0;
    virtual void response(const std::string& input_str, std::ostream* os = nullptr, const char* end_with = nullptr) = 0;
    virtual void response(const std::vector<std::pair<std::string, std::string>>& history, std::ostream* os = nullptr, const char* end_with = nullptr, int max_new_tokens = 512) = 0;
    virtual void response(const std::vector<int>& tokens, std::ostream* os = nullptr, const char* end_with = nullptr, int max_new_tokens = 512) = 0;
    virtual void reset() = 0;
    virtual bool stoped() = 0; // spelling matches MNN's llm.hpp and the call sites below
    virtual int generate(int max_token_number = 0) = 0;
    struct LlmContext {
        int prompt_len;
        int gen_seq_len;
        int64_t prefill_us;
        int64_t decode_us;
    };
    virtual LlmContext* getContext() = 0;
    virtual ~Llm() = default;
};
}
}
using namespace MNN::Transformer;
#endif

using ChatMessage = std::pair<std::string, std::string>;

@@ -288,12 +353,13 @@ private:
};

@implementation LLMInferenceEngineWrapper {
    std::shared_ptr<Llm> _llm;
    std::shared_ptr<MNN::Transformer::Llm> _llm;
    std::vector<ChatMessage> _history;
    std::mutex _historyMutex;
    std::atomic<bool> _isProcessing;
    std::atomic<bool> _isBenchmarkRunning;
    std::atomic<bool> _shouldStopBenchmark;
    std::atomic<bool> _shouldStopInference;
    NSString *_modelPath;
}

@@ -314,6 +380,7 @@ private:
    _isProcessing = false;
    _isBenchmarkRunning = false;
    _shouldStopBenchmark = false;
    _shouldStopInference = false;

    dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^{
        BOOL success = [self loadModelFromPath:modelPath];

@@ -335,13 +402,22 @@ private:
 * @return true if successful, false otherwise
 */
bool remove_directory_safely(const std::string& path) {
    try {
        if (std::filesystem::exists(path)) {
            std::filesystem::remove_all(path);
    @try {
        NSString *pathStr = [NSString stringWithUTF8String:path.c_str()];
        NSFileManager *fileManager = [NSFileManager defaultManager];

        if ([fileManager fileExistsAtPath:pathStr]) {
            NSError *error = nil;
            BOOL success = [fileManager removeItemAtPath:pathStr error:&error];
            if (!success && error) {
                NSLog(@"Error removing directory %s: %@", path.c_str(), error.localizedDescription);
                return false;
            }
            return success;
        }
        return true;
    } catch (const std::filesystem::filesystem_error& e) {
        NSLog(@"Error removing directory %s: %s", path.c_str(), e.what());
    } @catch (NSException *exception) {
        NSLog(@"Exception removing directory %s: %@", path.c_str(), exception.reason);
        return false;
    }
}

@@ -394,7 +470,7 @@ bool remove_directory_safely(const std::string& path) {
    std::string model_dir = [bundleDirectory UTF8String];
    std::string config_path = model_dir + "/config.json";

    _llm.reset(Llm::createLLM(config_path));
    _llm.reset(MNN::Transformer::Llm::createLLM(config_path));
    if (!_llm) {
        NSLog(@"Error: Failed to create LLM from bundle");
        return NO;

@@ -454,31 +530,40 @@ bool remove_directory_safely(const std::string& path) {
        return NO;
    }

    MNN::BackendConfig backendConfig;
    auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1);
    MNN::Express::ExecutorScope s(executor);

    // Get the memory-mapping setting with a default fallback
    BOOL useMmap = configDict[@"use_mmap"] == nil ? YES : [configDict[@"use_mmap"] boolValue];

    // Create the LLM instance with error checking
    _llm.reset(Llm::createLLM(config_path));
    _llm.reset(MNN::Transformer::Llm::createLLM(config_path));
    if (!_llm) {
        NSLog(@"Error: Failed to create LLM instance from config: %s", config_path.c_str());
        return NO;
    }

    // Set up the temporary directory with improved error handling
    std::string model_path_str([modelPath UTF8String]);
    std::string temp_directory_path = model_path_str + "/temp";
    // Use the iOS system temporary directory instead of the model path (which is read-only in the Bundle)
    NSString *tempDir = NSTemporaryDirectory();
    NSString *modelName = [[modelPath lastPathComponent] stringByDeletingPathExtension];
    NSString *tempDirPath = [tempDir stringByAppendingPathComponent:[NSString stringWithFormat:@"MNN_%@_temp", modelName]];
    std::string temp_directory_path = [tempDirPath UTF8String];

    // Clean up any existing temp directory
    if (!remove_directory_safely(temp_directory_path)) {
        NSLog(@"Warning: Failed to remove existing temp directory, continuing...");
    }

    // Create a new temp directory
    // Create a new temp directory in the system temp location
    if (mkdir(temp_directory_path.c_str(), 0755) != 0 && errno != EEXIST) {
        NSLog(@"Error: Failed to create temp directory: %s, errno: %d", temp_directory_path.c_str(), errno);
        return NO;
    }

    NSLog(@"Created temp directory at: %s", temp_directory_path.c_str());

    // Configure the LLM with proper error handling
    bool useMmapCpp = (useMmap == YES);
    std::string configStr = "{\"tmp_path\":\"" + temp_directory_path + "\", \"use_mmap\":" + (useMmapCpp ? "true" : "false") + "}";

@@ -585,12 +670,21 @@ bool remove_directory_safely(const std::string& path) {

    _isProcessing = true;

    // Store a reference for block execution
    LLMInferenceEngineWrapper *blockSelf = self;

    // Use a high-priority queue for better responsiveness
    dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^{
        // Check that the object is still valid before proceeding
        if (!blockSelf || !blockSelf->_llm) {
            NSLog(@"LLMInferenceEngineWrapper was deallocated or model unloaded during inference");
            return;
        }

        @try {
            auto inference_start_time = std::chrono::high_resolution_clock::now();

            OptimizedLlmStreamBuffer::CallBack callback = [output, self](const char* str, size_t len) {
            OptimizedLlmStreamBuffer::CallBack callback = [output](const char* str, size_t len) {
                if (output && str && len > 0) {
                    @autoreleasepool {
                        NSString *nsOutput = [[NSString alloc] initWithBytes:str

@@ -610,23 +704,77 @@ bool remove_directory_safely(const std::string& path) {

            // Thread-safe history management
            {
                std::lock_guard<std::mutex> lock(self->_historyMutex);
                self->_history.emplace_back(ChatMessage("user", [input UTF8String]));
                std::lock_guard<std::mutex> lock(blockSelf->_historyMutex);
                blockSelf->_history.emplace_back(ChatMessage("user", [input UTF8String]));
            }

            std::string inputStr = [input UTF8String];
            if (inputStr == "benchmark") {
                [self performBenchmarkWithOutput:&os];
                [blockSelf performBenchmarkWithOutput:&os];
            } else {
                // Get the initial context state for performance measurement
                auto context = self->_llm->getContext();
                auto context = blockSelf->_llm->getContext();
                int initial_prompt_len = context->prompt_len;
                int initial_decode_len = context->gen_seq_len;
                int64_t initial_prefill_time = context->prefill_us;
                int64_t initial_decode_time = context->decode_us;

                // Execute inference
                self->_llm->response(self->_history, &os, "<eop>", 999999);
                // Reset the stop flag before starting inference
                blockSelf->_shouldStopInference = false;

                // Execute inference with enhanced stopped-status checking
                @try {
                    // Debug information for the prompt
                    std::string prompt_debug = "";
                    for (const auto& msg : blockSelf->_history) {
                        prompt_debug += msg.first + ": " + msg.second + "\n";
                    }
                    NSLog(@"submitNative prompt_string_for_debug:\n%s\nmax_new_tokens_: %d", prompt_debug.c_str(), 999999);

                    // Start inference with initial response processing
                    blockSelf->_llm->response(blockSelf->_history, &os, "<eop>", 1);
                    int current_size = 1;
                    int max_new_tokens = 999999;

                    // Continue generation with precise token-by-token control
                    while (!blockSelf->_shouldStopInference.load() &&
                           !blockSelf->_llm->stoped() &&
                           current_size < max_new_tokens) {

                        // Generate a single token for maximum control
                        blockSelf->_llm->generate(1);
                        current_size++;

                        // Small delay to allow UI updates and stop-signal processing
                        std::this_thread::sleep_for(std::chrono::milliseconds(1));
                    }

                    // Send the appropriate end signal based on the stop reason
                    if (output) {
                        dispatch_async(dispatch_get_main_queue(), ^{
                            if (blockSelf->_shouldStopInference.load()) {
                                output(@"<stopped>");
                            } else {
                                output(@"<eop>");
                            }
                        });
                    }

                    NSLog(@"Inference completed. Generated tokens: %d, Stopped by user: %s, Model stopped: %s",
                          current_size,
                          blockSelf->_shouldStopInference.load() ? "YES" : "NO",
                          blockSelf->_llm->stoped() ? "YES" : "NO");

                } @catch (NSException *exception) {
                    NSLog(@"Exception during response generation: %@", exception.reason);

                    // Send the end signal even on error to unlock the UI
                    if (output) {
                        dispatch_async(dispatch_get_main_queue(), ^{
                            output(@"<eop>");
                        });
                    }
                }

                // Calculate performance metrics if requested
                if (showPerformance) {

@@ -683,7 +831,7 @@ bool remove_directory_safely(const std::string& path) {
            }
        }
        @finally {
            self->_isProcessing = false;
            blockSelf->_isProcessing = false;
        }
    });
}

@@ -770,17 +918,25 @@ bool remove_directory_safely(const std::string& path) {
}

/**
 * Enhanced deallocation with proper cleanup
 * Enhanced deallocation with proper cleanup and timeout
 */
- (void)dealloc {
    NSLog(@"LLMInferenceEngineWrapper deallocating...");

    // Stop any running benchmark
    _shouldStopBenchmark = true;
    // Actively cancel all operations first
    [self cancelInference];

    // Wait for any ongoing processing to complete
    while (_isProcessing.load() || _isBenchmarkRunning.load()) {
    // Wait for any ongoing processing to complete, with a timeout
    int timeout = 100; // 1 second timeout (100 * 10 ms)
    while ((_isProcessing.load() || _isBenchmarkRunning.load()) && timeout > 0) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        timeout--;
    }

    if (timeout <= 0) {
        NSLog(@"Warning: Dealloc timeout, forcing cleanup");
        _isProcessing = false;
        _isBenchmarkRunning = false;
    }

    {

@@ -879,11 +1035,17 @@ bool remove_directory_safely(const std::string& path) {
 * Cancel ongoing inference (if supported)
 */
- (void)cancelInference {
    if (_isProcessing.load()) {
        NSLog(@"Inference cancellation requested");
        // Note: Actual cancellation depends on the MNN LLM implementation
        // This is a placeholder for future enhancement
    }
    NSLog(@"Cancelling inference...");

    // Set all stop flags to true
    _shouldStopInference = true;
    _shouldStopBenchmark = true;

    // Force processing states to false for immediate cleanup
    _isProcessing = false;
    _isBenchmarkRunning = false;

    NSLog(@"Inference cancellation completed - all flags set");
}

/**
@@ -1041,7 +1041,7 @@
      }
    },
    "tag.deepThinking" : {
      "extractionState" : "stale",
      "comment" : "Deep thinking tag for local model",
      "localizations" : {
        "en" : {
          "stringUnit" : {

@@ -1091,6 +1091,23 @@
        }
      }
    },
    "tag.localModel" : {
      "comment" : "Local model inside the app",
      "localizations" : {
        "en" : {
          "stringUnit" : {
            "state" : "translated",
            "value" : "Built-in Model"
          }
        },
        "zh-Hans" : {
          "stringUnit" : {
            "state" : "translated",
            "value" : "内置模型"
          }
        }
      }
    },
    "tag.math" : {
      "extractionState" : "stale",
      "localizations" : {

@@ -1292,6 +1309,16 @@
    },
    "Yes" : {

    },
    "搜索本地模型..." : {
      "localizations" : {
        "en" : {
          "stringUnit" : {
            "state" : "translated",
            "value" : "Search local models …"
          }
        }
      }
    },
    "version" : "1.0"
@@ -9,10 +9,25 @@ import SwiftUI

struct LocalModelListView: View {
    @ObservedObject var viewModel: ModelListViewModel
    @State private var localSearchText = ""

    private var filteredLocalModels: [ModelInfo] {
        let downloadedModels = viewModel.models.filter { $0.isDownloaded }

        if localSearchText.isEmpty {
            return downloadedModels
        } else {
            return downloadedModels.filter { model in
                model.id.localizedCaseInsensitiveContains(localSearchText) ||
                model.modelName.localizedCaseInsensitiveContains(localSearchText) ||
                model.localizedTags.contains { $0.localizedCaseInsensitiveContains(localSearchText) }
            }
        }
    }

    var body: some View {
        List {
            ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
            ForEach(filteredLocalModels, id: \.id) { model in
                Button(action: {
                    viewModel.selectModel(model)
                }) {

@@ -25,6 +40,7 @@ struct LocalModelListView: View {
            }
        }
        .listStyle(.plain)
        .searchable(text: $localSearchText, prompt: "搜索本地模型...")
        .refreshable {
            await viewModel.fetchModels()
        }
@@ -7,7 +7,6 @@

import SwiftUI

// MainTabView is the primary view of the app, containing the tab bar and navigation for the main sections.
struct MainTabView: View {
    // MARK: - State Properties
@@ -85,8 +85,34 @@ struct ModelInfo: Codable {
    // MARK: - File System & Path Management

    var localPath: String {
        let modelScopeId = "taobao-mnn/\(modelName)"
        return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
        // Check whether this is a local model from the LocalModel folder or the Bundle root
        if let sources = sources, let localSource = sources["local"] {
            guard let bundlePath = Bundle.main.resourcePath else {
                return ""
            }

            // Check whether this is a flattened model (files directly in the Bundle root)
            if localSource.hasPrefix("bundle_root/") {
                // For flattened models, return the Bundle root path
                // The model files are directly in the Bundle root directory
                return bundlePath
            } else {
                // Original LocalModel folder structure
                let localModelPath = (bundlePath as NSString).deletingLastPathComponent + "/LocalModel"

                // If modelName is "LocalModel", return the LocalModel folder directly
                if modelName == "LocalModel" {
                    return localModelPath
                } else {
                    // For subdirectory models, append the model name
                    return (localModelPath as NSString).appendingPathComponent(modelName)
                }
            }
        } else {
            // For downloaded models, use the original Hub API path
            let modelScopeId = "taobao-mnn/\(modelName)"
            return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
        }
    }

    // MARK: - Size Calculation & Formatting
@@ -11,7 +11,6 @@ import SwiftUI
class ModelListViewModel: ObservableObject {
    // MARK: - Published Properties
    @Published var models: [ModelInfo] = []
    @Published var searchText = ""
    @Published var quickFilterTags: [String] = []
    @Published var selectedModel: ModelInfo?
    @Published var showError = false

@@ -44,19 +43,6 @@ class ModelListViewModel: ObservableObject {
        Array(Set(models.compactMap { $0.vendor }))
    }

    var filteredModels: [ModelInfo] {
        let filtered = searchText.isEmpty ? models : models.filter { model in
            model.id.localizedCaseInsensitiveContains(searchText) ||
            model.modelName.localizedCaseInsensitiveContains(searchText) ||
            model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
        }

        let downloaded = filtered.filter { $0.isDownloaded }
        let notDownloaded = filtered.filter { !$0.isDownloaded }

        return downloaded + notDownloaded
    }

    // MARK: - Initialization

    init() {

@@ -67,6 +53,118 @@ class ModelListViewModel: ObservableObject {

    // MARK: - Model Data Management

    /// Load models from the Bundle root directory (LocalModel folder structure flattened)
    private func loadLocalModels() async -> [ModelInfo] {
        let fileManager = FileManager.default
        var localModels: [ModelInfo] = []

        guard let resourcePath = Bundle.main.resourcePath else {
            return localModels
        }

        do {
            let contents = try fileManager.contentsOfDirectory(atPath: resourcePath)

            // Check whether model files sit directly in the Bundle root
            let modelFiles = ["config.json", "llm_config.json", "llm.mnn", "tokenizer.txt"]
            let foundModelFiles = contents.filter { modelFiles.contains($0) }

            if !foundModelFiles.isEmpty {
                // Check whether we have a complete model (at least config.json)
                if foundModelFiles.contains("config.json") {
                    let modelName = "Qwen3-0.6B-MNN"
                    let localModel = ModelInfo(
                        modelName: modelName,
                        tags: [NSLocalizedString("tag.deepThinking", comment: "Deep thinking tag for local model"),
                               NSLocalizedString("tag.localModel", comment: "Local model inside the app")],
                        categories: ["Local Models"],
                        vendor: "Local",
                        sources: ["local": "bundle_root/\(modelName)"],
                        isDownloaded: true
                    )
                    localModels.append(localModel)

                    ModelStorageManager.shared.markModelAsDownloaded(modelName)
                }
            } else {
                // Fallback: try to find the LocalModel folder
                let localModelPath = (resourcePath as NSString).appendingPathComponent("LocalModel")
                var isDirectory: ObjCBool = false

                if fileManager.fileExists(atPath: localModelPath, isDirectory: &isDirectory), isDirectory.boolValue {
                    localModels.append(contentsOf: await processLocalModelFolder(at: localModelPath))
                }
            }

        } catch {
            // Silently handle the error
        }

        return localModels
    }

    /// Process the LocalModel folder (fallback for the non-flattened structure)
    private func processLocalModelFolder(at validPath: String) async -> [ModelInfo] {
        let fileManager = FileManager.default
        var localModels: [ModelInfo] = []

        // Check whether this is a valid model directory (contains config.json)
        let configPath = (validPath as NSString).appendingPathComponent("config.json")
        if fileManager.fileExists(atPath: configPath) {
            let modelName = "LocalModel"
            let localModel = ModelInfo(
                modelName: modelName,
                tags: ["local", "bundled"],
                categories: ["Local Models"],
                vendor: "Local",
                sources: ["local": "local/\(modelName)"],
                isDownloaded: true
            )
            localModels.append(localModel)

            ModelStorageManager.shared.markModelAsDownloaded(modelName)
        } else {
            // Check for subdirectories that might contain models
            do {
                let contents = try fileManager.contentsOfDirectory(atPath: validPath)

                for item in contents {
                    // Skip hidden files and common non-model files
                    if item.hasPrefix(".") || item == "bench.txt" {
                        continue
                    }

                    let itemPath = (validPath as NSString).appendingPathComponent(item)
                    var isItemDirectory: ObjCBool = false

                    if fileManager.fileExists(atPath: itemPath, isDirectory: &isItemDirectory),
                       isItemDirectory.boolValue {

                        let itemConfigPath = (itemPath as NSString).appendingPathComponent("config.json")

                        if fileManager.fileExists(atPath: itemConfigPath) {
                            let localModel = ModelInfo(
                                modelName: item,
                                tags: ["local", "bundled"],
                                categories: ["Local Models"],
                                vendor: "Local",
                                sources: ["local": "local/\(item)"],
                                isDownloaded: true
                            )
                            localModels.append(localModel)

                            ModelStorageManager.shared.markModelAsDownloaded(item)
                        }
                    }
                }
            } catch {
                // Silently handle the error
            }
        }

        return localModels
    }

    @MainActor
    func fetchModels() async {
        do {

@@ -77,6 +175,10 @@ class ModelListViewModel: ObservableObject {

        var fetchedModels = info.models

        // Add LocalModel folder models
        let localModels = await loadLocalModels()
        fetchedModels.append(contentsOf: localModels)

        filterDiffusionModels(fetchedModels: &fetchedModels)
        loadCachedSizes(for: &fetchedModels)
        sortModels(fetchedModels: &fetchedModels)
@@ -28,9 +28,6 @@ struct ModelListView: View {
        }
    }
    .searchable(text: $searchText, prompt: "Search models...")
    .onChange(of: searchText) { _, newValue in
        viewModel.searchText = newValue
    }
    .refreshable {
        await viewModel.fetchModels()
    }

@@ -102,21 +99,31 @@ struct ModelListView: View {
        }
    }

    // Filter models based on the selected tags, categories, and vendors
    private var filteredModels: [ModelInfo] {
        let baseFiltered = viewModel.filteredModels

        let searchFiltered = searchText.isEmpty ? viewModel.models : viewModel.models.filter { model in
            model.id.localizedCaseInsensitiveContains(searchText) ||
            model.modelName.localizedCaseInsensitiveContains(searchText) ||
            model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
        }

        let tagFiltered: [ModelInfo]
        if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
            return baseFiltered
            tagFiltered = searchFiltered
        } else {
            tagFiltered = searchFiltered.filter { model in
                let tagMatch = checkTagMatch(model: model)
                let categoryMatch = checkCategoryMatch(model: model)
                let vendorMatch = checkVendorMatch(model: model)

                return tagMatch && categoryMatch && vendorMatch
            }
        }

        return baseFiltered.filter { model in
            let tagMatch = checkTagMatch(model: model)
            let categoryMatch = checkCategoryMatch(model: model)
            let vendorMatch = checkVendorMatch(model: model)

            return tagMatch && categoryMatch && vendorMatch
        }
        let downloaded = tagFiltered.filter { $0.isDownloaded }
        let notDownloaded = tagFiltered.filter { !$0.isDownloaded }

        return downloaded + notDownloaded
    }

    // Extract the tag-matching logic into an independent method
@@ -7,7 +7,6 @@

import SwiftUI

// MARK: - Action buttons view
struct ActionButtonsView: View {
    let model: ModelInfo
    @ObservedObject var viewModel: ModelListViewModel
@@ -7,7 +7,6 @@

import SwiftUI

// MARK: - Downloaded button view
struct DownloadedButtonView: View {
    @Binding var showDeleteAlert: Bool
@@ -7,7 +7,6 @@

import SwiftUI

// MARK: - Downloading button view
struct DownloadingButtonView: View {
    @ObservedObject var viewModel: ModelListViewModel
    let downloadProgress: Double
@@ -7,7 +7,6 @@

import SwiftUI

// MARK: - Tag chip
struct TagChip: View {
    let text: String

@@ -22,4 +21,4 @@ struct TagChip: View {
            .stroke(Color.secondary.opacity(0.3), lineWidth: 0.5)
        )
    }
}
@@ -7,7 +7,6 @@

import SwiftUI

// MARK: - Tags view
struct TagsView: View {
    let tags: [String]

@@ -22,4 +21,4 @@ struct TagsView: View {
        }
        .frame(height: 25)
    }
}
@@ -12,6 +12,12 @@ struct SwipeActionsView: View {
    let model: ModelInfo
    @ObservedObject var viewModel: ModelListViewModel

    private func isBuiltInLocalModel(_ model: ModelInfo) -> Bool {
        guard let vendor = model.vendor, vendor == "Local" else { return false }
        guard let sources = model.sources, let localSource = sources["local"] else { return false }
        return localSource.hasPrefix("bundle_root/")
    }

    var body: some View {
        if viewModel.pinnedModelIds.contains(model.id) {
            Button {

@@ -26,7 +32,7 @@ struct SwipeActionsView: View {
            Label(LocalizedStringKey("button.pin"), systemImage: "pin")
        }.tint(.primaryBlue)
    }
    if model.isDownloaded {
    if model.isDownloaded && !isBuiltInLocalModel(model) {
        Button(role: .destructive) {
            Task {
                await viewModel.deleteModel(model)
@@ -0,0 +1,17 @@
//
//  NotificationNames.swift
//  MNNLLMiOS
//
//  Created by 游薪渝(揽清) on 2025/8/1.
//

import Foundation

extension Notification.Name {

    static let dismissKeyboard = Notification.Name("dismissKeyboard")

    static let messageSent = Notification.Name("messageSent")

    static let onScrollToBottom = Notification.Name("onScrollToBottom")
}
@@ -154,6 +154,11 @@ extension Date {
        if result.contains("second") {
            return "Just now"
        }

        if result.contains("秒钟") {
            return "刚刚"
        }

        return result
    }
}
@@ -111,48 +111,23 @@ Because iPhone memory is limited, models of 7B or below are recommended, to avoid memory

## Local Debugging

If you want to download models directly on your computer for debugging, rather than downloading them inside the app, you can proceed as follows.
Debugging a model locally is very simple: just drag the model files into the LocalModel folder and run the project:

1. First, download an MNN model from [huggingface](https://huggingface.co/taobao-mnn) or [modelscope](https://modelscope.cn/organization/MNN)

<img width="400" alt="image" src="./assets/copyLocalModel.png" />


2. Drag all files inside the downloaded model folder into the LocalModel folder of the project:

<img width="200" alt="image" src="./assets/copyLocalModel2.png" />


3. Make sure all of the files above are listed in Copy Bundle Resources


<img width="400" alt="image" src="./assets/copyLocalMode3.png" />

4. Run the project, open the chat page, and chat with and debug the model.

4. Comment out the download-related code

```Swift
/*
try await modelClient.downloadModel(model: model) { progress in
    Task { @MainActor in
        DispatchQueue.main.async {
            self.downloadProgress[model.modelId] = progress
        }
    }
}
*/
```
5. Change how the model is loaded

Modify this in the LLMInferenceEngineWrapper class:

```Swift
// BOOL success = [self loadModelFromPath:modelPath];
// MARK: Test Local Model
BOOL success = [self loadModel];
```

6. Run the project, open the chat page, and chat with and debug the model.
The app automatically detects and loads models in the LocalModel folder; no extra configuration is needed.


## Release Notes
@@ -125,14 +125,12 @@ Here is the professional technical translation of the provided text:

## Local Debugging

If we want to directly download the models to the computer for debugging, without downloading them through the app, we can follow these steps:
For local debugging, simply drag the model files into the LocalModel folder and run the project:

1. First, download the MNN-related models from [Hugging Face](https://huggingface.co/taobao-mnn) or [Modelscope](https://modelscope.cn/organization/MNN):

<img width="400" alt="image" src="./assets/copyLocalModel.png" />

2. After downloading, drag all the files from the model folder into the project's LocalModel folder:

<img width="300" alt="image" src="./assets/copyLocalModel2.png" />

@@ -141,33 +139,9 @@ If we want to directly download the models to the computer for debugging without

<img width="400" alt="image" src="./assets/copyLocalMode3.png" />

4. Comment out the model download code:
4. Run the project, navigate to the chat page, and perform model interactions and debugging.

```Swift
/*
try await modelClient.downloadModel(model: model) { progress in
    Task { @MainActor in
        DispatchQueue.main.async {
            self.downloadProgress[model.modelId] = progress
        }
    }
}
*/
```


5. Modify the Model Loading Method

Modify the `LLMInferenceEngineWrapper` class:

```Swift
// BOOL success = [self loadModelFromPath:modelPath];
// MARK: Test Local Model
BOOL success = [self loadModel];
```


6. Run the project, navigate to the chat page, and perform model interactions and debugging.
The app will automatically detect and load models from the LocalModel folder without requiring additional configuration.

## Release Notes
@@ -16,11 +16,11 @@ MNN is built with CMake; the macros defined in CMake are listed below:
| MNN_BUILD_QUANTOOLS | Whether to build MNN's quantization tools, default `OFF` |
| MNN_EVALUATION | Whether to build MNN's evaluation tools, default `OFF` |
| MNN_BUILD_CONVERTER | Whether to build MNN's model converter, default `OFF` |
| MNN_SUPPORT_QUNAT_EXTEND | Whether to compile quantized versions of non-core operators, default `ON` |
| MNN_SUPPORT_DEPRECATED_OP | Whether to support deprecated operators such as TFLite's quantized operators, for compatibility with historical models (before 1.1.0), default `OFF` |
| MNN_SUPPORT_DEPRECATED_OPV2 | Whether to compile operators deprecated since MNN 3.0, for compatibility with historical models (before 3.0.0); for example, Convolution3D and ConvTranspose3D are converted to the corresponding 2D operators by the model converter after 3.0.0 and no longer need runtime support, default `ON` |
| MNN_REDUCE_SIZE | Whether to trim the MNN library size by removing gradient operators and reducing optimization strategies, default `OFF`; when enabled, MNN_SUPPORT_QUANT_EXTEND / MNN_SUPPORT_DEPRECATED_OP / MNN_SUPPORT_DEPRECATED_OPV2 are all set to OFF |
| MNN_SUPPORT_QUANT_EXTEND | Whether to enable quantized computation for operators such as Binary/Unary, default `ON` |
| MNN_SUPPORT_QUANT_EXTEND | Whether to compile quantized versions of non-core operators, default `ON` |
| MNN_SUPPORT_DEPRECATED_OP | Whether to support deprecated operators such as TFLite's quantized operators, for compatibility with historical models (before 1.1.0), default `OFF` |
| MNN_SUPPORT_DEPRECATED_OPV2 | Whether to compile operators deprecated since MNN 3.0, for compatibility with historical models (before 3.0.0); for example, `Convolution3D` / `ConvTranspose3D` are converted to the corresponding 2D operators by the model converter after 3.0.0 and no longer need runtime support, default `ON` |
| MNN_REDUCE_SIZE | Whether to trim the MNN library size by removing gradient operators and reducing optimization strategies, default `OFF`; when enabled, `MNN_SUPPORT_QUANT_EXTEND` / `MNN_SUPPORT_DEPRECATED_OP` / `MNN_SUPPORT_DEPRECATED_OPV2` / `MNN_USE_SPARSE_COMPUTE` are all set to `OFF` |
| MNN_SUPPORT_QUANT_EXTEND | Whether to enable quantized computation for operators such as Binary/Unary, default `ON` |
| MNN_DEBUG_MEMORY | Whether to enable MNN memory debugging, default `OFF` |
| MNN_DEBUG_TENSOR_SIZE | Whether to enable MNN tensor-size debugging, default `OFF` |
| MNN_GPU_TRACE | Whether to enable MNN GPU debugging, default `OFF` |

@@ -30,7 +30,7 @@ MNN is built with CMake; the macros defined in CMake are listed below:
| MNN_AAPL_FMWK | Whether to build `MNN.framework` instead of `*.dylib`, default `OFF` |
| MNN_WITH_PLUGIN | Whether to support `Plugin operators`, default `OFF` |
| MNN_SKIPBUILD_GEOMETRY | Whether to skip building MNN's geometry computation; if so, the MNN engine only supports fixed-input-shape models converted with --saveStaticModel, default `OFF` |
| MNN_BUILD_MINI | Whether to build the minimal version of MNN; if so, enables MNN_SKIPBUILD_GEOMETRY and MNN_REDUCE_SIZE, default `OFF` |
| MNN_BUILD_MINI | Whether to build the minimal version of MNN; if so, enables `MNN_SKIPBUILD_GEOMETRY` and `MNN_REDUCE_SIZE`, default `OFF` |
| MNN_USE_SSE | Whether to use the SSE instruction set on x86, default `OFF` |
| MNN_BUILD_CODEGEN | Whether to build MNN's code-generation component, which provides operator fusion and code generation; experimental, default `OFF` |
| MNN_ENABLE_COVERAGE | Whether to enable MNN code coverage, default `OFF` |

@@ -59,10 +59,12 @@ MNN is built with CMake; the macros defined in CMake are listed below:
| MNN_NNAPI | Whether to build the `NNAPI` backend, default `OFF` |
| MNN_QNN | Whether to build the `QNN` backend, default `OFF` |
| MNN_QNN_CONVERT_MODE | With `MNN_QNN` enabled, whether to build the QNN backend in Convert mode, default `OFF` |
| MNN_NPU | Whether to build the `NPU` backend, default `OFF` |
| MNN_USE_SPARSE_COMPUTE | Whether to use sparse computation, default `ON` |
| MNN_BUILD_BENCHMARK | Whether to build MNN benchmarks, default `OFF` |
| MNN_BUILD_TEST | Whether to build MNN unit tests, default `OFF` |
| MNN_BUILD_FOR_ANDROID_COMMAND | Whether to build for `Android` from the command line, default `OFF` |
| MNN_USE_LOGCAT | Whether to use `logcat` instead of `printf` for log output, default `OFF` |
| MNN_USE_LOGCAT | Whether to use `logcat` instead of `printf` for log output, default `ON` |
| MNN_USE_CPP11 | Whether to compile MNN with `C++11`, default `ON` |
| MNN_SUPPORT_BF16 | Whether to support `BF16`, default `OFF` |
| MNN_SSE_USE_FP16_INSTEAD | Whether to use `FP16` instead of `BF16` on X86, default `OFF` |

@@ -72,7 +74,7 @@ MNN is built with CMake; the macros defined in CMake are listed below:
| MNN_METALLIB_SOURCE | Whether to use the Metal source directly when using Metal; only effective when `MNN_METAL=ON`, default `ON` |
| MNN_VULKAN_DEBUG | Whether to enable Vulkan DEBUG mode; only effective when `MNN_VULKAN=ON`, default `OFF` |
| MNN_OPENGL_REGEN | Whether to regenerate the OpenGL kernels; only effective when `MNN_OPENGL=ON`, default `OFF` |
| MNN_TRT_DYNAMIC | Whether to load the TRT dynamic library via dlopen; only effective when `MNN_TENSORRT=ON`, default `OFF |
| MNN_TRT_DYNAMIC | Whether to load the TRT dynamic library via dlopen; only effective when `MNN_TENSORRT=ON`, default `OFF` |
| MNN_BUILD_TORCH | Whether the built `MNNConvert` supports `TorchScript`; only effective when `MNN_BUILD_CONVERTER=ON`, default `OFF` |
| MNN_TRAIN_DEBUG | Whether the built training module supports debugging; only effective when `MNN_BUILD_TRAIN=ON`, default `OFF` |
| MNN_USE_OPENCV | Whether the built training demo uses the `OpenCV` dependency; only effective when `MNN_BUILD_TRAIN=ON`, default `OFF` |
@@ -56,11 +56,11 @@
- `llm_bench` LLM benchmarking tool
## Test Tools
- Related build options
  - `MNN_BUILD_TOOL` Whether to build the test tools
  - `MNN_BUILD_TOOLS` Whether to build the test tools
- Build commands
  ```bash
  mkdir build && cd build
  cmake .. -DMNN_BUILD_TOOL=ON
  cmake .. -DMNN_BUILD_TOOLS=ON
  make -j4
  ```
- Build outputs
@@ -2,7 +2,7 @@
## Local Installation
```bash
cd /path/to/MNN/pymnn/pip_package
python build_deps.py {MNN dependency combination} # any combination of internal,cuda,trt,cuda_tune,opencl,vulkan,render,no_sse,torch — for example "cuda,render,no_sse"
python build_deps.py {MNN dependency combination} # any combination of internal,cuda,trt,cuda_tune,opencl,vulkan,render,no_sse,torch,openmp,llm — for example "cuda,render,no_sse"
python setup.py install --version {MNN version} --deps {MNN dependency combination}
```
## Building the Python Wheel Package
@@ -92,7 +92,8 @@ rtmgr->setMode(Interpreter::Session_Debug);

- Interpreter::HintMode::WINOGRAD_MEMORY_LEVEL: memory-usage preference when optimizing convolutions with the Winograd algorithm; the default is 3; set it to 0 to reduce memory usage
- Interpreter::HintMode::GEOMETRY_COMPUTE_MASK: switches for geometry-compute optimizations; 1 enables region merging, 2 compound region merging, 4 the loop operator, and 8 geometry recomputation; add the values together to enable multiple features. The default enables all of them.
- Interpreter::HintMode::CPU_LITTLECORE_DECREASE_RATE: on Android devices with big/medium/little cores, the compute-power decay ratio from big cores to medium cores. The default is 50 (medium cores have 50% of big-core compute power)
- Interpreter::HintMode::CPU_LITTLECORE_DECREASE_RATE: on Android devices with big/medium/little cores, sets the compute-power decay ratio between big and little cores, used for task scheduling. The default is 50, meaning a little core has 50% of a big core's compute power. MNN uses this ratio to decide how much work to assign to big and little cores. This parameter does **not** directly pin threads to specific cores; it only affects the task-distribution strategy.
- Interpreter::HintMode::CPU_CORE_IDS: binds MNN's compute tasks directly to the specified CPU cores. This is a stronger form of control that precisely restricts the CPU resources MNN uses. See [Session API usage – CPU core binding](../inference/session.md#cpu-核心绑定) for details.


#### ExternalPath
@ -47,7 +47,7 @@ import MNN.cv as cv
|
|||
import MNN.numpy as np
|
||||
import MNN.expr as expr
|
||||
|
||||
# 配置执行后端,线程数,精度等信息;key-vlaue请查看API介绍
|
||||
# 配置执行后端,线程数,精度等信息;key-value请查看API介绍
|
||||
config = {}
|
||||
config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理
|
||||
config['backend'] = 0 # CPU
|
||||
|
|
|
@@ -200,6 +200,39 @@ struct BackendConfig {

`sharedContext` is used for custom backends; users can assign it as needed.

#### CPU Core Binding
MNN supports binding compute tasks to specified CPU cores. This is very useful when you need fine-grained control over CPU resources, want to keep threads from migrating between cores, or want to confine MNN's computation to particular cores to reduce interference with other applications.

One or more CPU core IDs can be specified through the `Interpreter::setSessionHint` method with `HintMode::CPU_CORE_IDS`.

```cpp
#include <MNN/Interpreter.hpp>

// ...

// Create the Interpreter
auto interpreter = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile("your_model.mnn"));

// Set up CPU core binding
// Suppose we want to bind the compute tasks to CPU cores 0 and 1
std::vector<int> cpu_ids = {0, 1};
interpreter->setSessionHint(MNN::Interpreter::HintMode::CPU_CORE_IDS, cpu_ids.data(), cpu_ids.size());

// Create the Session
MNN::ScheduleConfig config;
config.type = MNN_FORWARD_CPU;
config.numThread = 2; // the thread count should ideally match the number of bound cores
auto session = interpreter->createSession(config);

// ... run inference
```

**Notes:**

* `CPU_CORE_IDS` must be set before `createSession`.
* `numThread` should ideally equal the number of bound CPU cores for the best performance.
* If a specified CPU core ID does not exist or is invalid, MNN ignores the setting and falls back to the default thread-scheduling strategy.

### Creating a Multi-Path Session
When the inference path requires a more complex configuration, it can be achieved with a group of schedule configs:
```cpp
@@ -11,9 +11,11 @@ MNN is the most basic module in Pymnn, containing the data structures needed by the V2 API and …
- [expr](expr.md)
- [numpy](numpy.md)
- [cv](cv.md)
- [nn](nn.md)
- [nn](nn.md): contains the `loss` and `compress` submodules
- [optim](optim.md)
- [data](data.md)
- [audio](audio.md)
- [llm](llm.md)

---
### `MNN Types`
@@ -0,0 +1,3 @@
# audio

This document is a placeholder.

@@ -0,0 +1,3 @@
# llm

This document is a placeholder.
@@ -103,7 +103,7 @@ if (MNN_KLEIDIAI)
    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c
)

set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16)
set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16")
set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2")

endif()
@@ -211,6 +211,11 @@ namespace MNN {
        return false;
    }
}
if(type == AccelType::QI4_SYM_CHNLQT_F32){
    if(common->inputCount() % 2 != 0) {
        return false;
    }
}
if(common->kernelX() == 1 && common->kernelY() == 1
    && common->padX() == 0 && common->padY() == 0
    && common->strideX() == 1 && common->strideY() == 1
@@ -175,8 +175,11 @@ ErrorCode KleidiAIConvolution::onResize(const std::vector<Tensor *> &inputs, con
    if (outputOriginFmt != MNN_DATA_FORMAT_NHWC){
        b->onReleaseBuffer(mOutputConvertBuffer.get(), Backend::DYNAMIC);
    }

    mPostParameters = getPostParameters();
    return NO_ERROR;
}

ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input = inputs[0];
    auto output = outputs[0];

@@ -209,13 +212,12 @@ ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, co
    }

    auto outputDes = TensorUtils::getDescribe(outputs[0]);
    auto postPtr = getPostParameters();
    auto outputPtr = output->host<uint8_t>();
    if(outputDes->dimensionFormat != MNN_DATA_FORMAT_NHWC){
        outputPtr = mOutputConvertBuffer->host<uint8_t>();
    }

    kai.runMatmul(mAccelType, m, n, k, 0, lhsPacked, weightPtr, outputPtr, n * elementSize, elementSize, postPtr[3], postPtr[2]);
    kai.runMatmul(mAccelType, m, n, k, 0, lhsPacked, weightPtr, outputPtr, n * elementSize, elementSize, mPostParameters[3], mPostParameters[2]);

    if(outputDes->dimensionFormat != MNN_DATA_FORMAT_NHWC){
        MNN_CONCURRENCY_BEGIN(tId, threadNum) {

@@ -28,7 +28,7 @@ class KleidiAIConvolution : public CPUConvolution{
    std::shared_ptr<Tensor> mOutputConvertBuffer;
    std::shared_ptr<CPUConvolution::Resource> mResource;
    KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER;

    std::vector<float> mPostParameters;
};
#endif //MNN_KLEIDIAI_ENABLED

@@ -37,7 +37,7 @@ struct ConvParams {
    int dilatedKernelWidth = kernelSizeWithDilated(kernelWidth, dilatedWidth);

    int outputHeight = outputSize(inputHeight, padTop, padBottom, dilatedKernelHeight, strideHeight);
    int outputWidth = outputSize(inputHeight, padLeft, padRight, dilatedKernelWidth, strideWidth);
    int outputWidth = outputSize(inputWidth, padLeft, padRight, dilatedKernelWidth, strideWidth);

    return {outputHeight, outputWidth};
}