mirror of https://github.com/alibaba/MNN.git
[feat] change data source
This commit is contained in:
parent
3d2091bc24
commit
85855f7dbc
|
@ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
|
|||
PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
|
||||
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
|
||||
|
||||
if let isDeepSeek = self?.modelInfo.name.lowercased().contains("deepseek"), isDeepSeek == true,
|
||||
if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
|
||||
let text = self?.processor.process(progress: message.text) {
|
||||
updateLastMsg?.text = text
|
||||
} else {
|
||||
|
|
|
@ -25,13 +25,13 @@ final class LLMChatData {
|
|||
|
||||
self.assistant = LLMChatUser(
|
||||
uid: "2",
|
||||
name: modelInfo.name,
|
||||
name: modelInfo.modelName,
|
||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||
)
|
||||
|
||||
self.system = LLMChatUser(
|
||||
uid: "0",
|
||||
name: modelInfo.name,
|
||||
name: modelInfo.modelName,
|
||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||
)
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
let modelConfigManager: ModelConfigManager
|
||||
|
||||
var isDiffusionModel: Bool {
|
||||
return modelInfo.name.lowercased().contains("diffusion")
|
||||
return modelInfo.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||
|
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
), userType: .system)
|
||||
}
|
||||
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
||||
Task { @MainActor in
|
||||
print("Diffusion Model \(success)")
|
||||
|
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func sendToLLM(draft: DraftMessage) {
|
||||
self.send(draft: draft, userType: .user)
|
||||
if isModelLoaded {
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
self.getDiffusionResponse(draft: draft)
|
||||
} else {
|
||||
self.getLLMRespsonse(draft: draft)
|
||||
|
@ -284,7 +284,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
}
|
||||
|
||||
private func convertDeepSeekMutliChat(content: String) -> String {
|
||||
if self.modelInfo.name.lowercased().contains("deepseek") {
|
||||
if self.modelInfo.modelName.lowercased().contains("deepseek") {
|
||||
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
*/
|
||||
|
@ -337,7 +337,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
ChatHistoryManager.shared.saveChat(
|
||||
historyId: historyId,
|
||||
modelId: modelInfo.modelId,
|
||||
modelName: modelInfo.name,
|
||||
modelName: modelInfo.modelName,
|
||||
messages: messages
|
||||
)
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
let modelConfigManager: ModelConfigManager
|
||||
|
||||
var isDiffusionModel: Bool {
|
||||
return modelInfo.name.lowercased().contains("diffusion")
|
||||
return modelInfo.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||
|
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
), userType: .system)
|
||||
}
|
||||
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
||||
Task { @MainActor in
|
||||
print("Diffusion Model \(success)")
|
||||
|
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func sendToLLM(draft: DraftMessage) {
|
||||
self.send(draft: draft, userType: .user)
|
||||
if isModelLoaded {
|
||||
if modelInfo.name.lowercased().contains("diffusion") {
|
||||
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||
self.getDiffusionResponse(draft: draft)
|
||||
} else {
|
||||
self.getLLMRespsonse(draft: draft)
|
||||
|
@ -298,7 +298,7 @@ final class LLMChatViewModel: ObservableObject {
|
|||
}
|
||||
|
||||
private func convertDeepSeekMutliChat(content: String) -> String {
|
||||
if self.modelInfo.name.lowercased().contains("deepseek") {
|
||||
if self.modelInfo.modelName.lowercased().contains("deepseek") {
|
||||
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||
*/
|
||||
|
@ -350,8 +350,8 @@ final class LLMChatViewModel: ObservableObject {
|
|||
func onStop() {
|
||||
ChatHistoryManager.shared.saveChat(
|
||||
historyId: historyId,
|
||||
modelId: modelInfo.modelId,
|
||||
modelName: modelInfo.name,
|
||||
modelId: modelInfo.id,
|
||||
modelName: modelInfo.modelName,
|
||||
messages: messages
|
||||
)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ struct LLMChatView: View {
|
|||
@State private var showSettings = false
|
||||
|
||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||
self.title = modelInfo.name
|
||||
self.title = modelInfo.modelName
|
||||
self.modelPath = modelInfo.localPath
|
||||
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
|
||||
_viewModel = StateObject(wrappedValue: viewModel)
|
||||
|
|
|
@ -12,14 +12,14 @@ struct LocalModelListView: View {
|
|||
|
||||
var body: some View {
|
||||
List {
|
||||
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.modelId) { model in
|
||||
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
|
||||
Button(action: {
|
||||
viewModel.selectModel(model)
|
||||
}) {
|
||||
LocalModelRowView(model: model)
|
||||
}
|
||||
.listRowSeparator(.hidden)
|
||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear)
|
||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
|
||||
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
||||
SwipeActionsView(model: model, viewModel: viewModel)
|
||||
}
|
||||
|
|
|
@ -11,36 +11,28 @@ struct LocalModelRowView: View {
|
|||
|
||||
let model: ModelInfo
|
||||
|
||||
private var localizedTags: [String] {
|
||||
model.localizedTags
|
||||
}
|
||||
|
||||
private var formattedSize: String {
|
||||
model.formattedSize
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .center) {
|
||||
|
||||
ModelIconView(modelId: model.modelId)
|
||||
ModelIconView(modelId: model.id)
|
||||
.frame(width: 50, height: 50)
|
||||
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
Text(model.name)
|
||||
Text(model.modelName)
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
.lineLimit(1)
|
||||
|
||||
if !model.tags.isEmpty {
|
||||
ScrollView(.horizontal, showsIndicators: false) {
|
||||
HStack {
|
||||
ForEach(model.tags, id: \.self) { tag in
|
||||
Text(tag)
|
||||
.fontWeight(.regular)
|
||||
.font(.caption)
|
||||
.foregroundColor(Color(red: 151/255, green: 151/255, blue: 151/255))
|
||||
.padding(.horizontal, 10)
|
||||
.padding(.vertical, 4)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 10)
|
||||
.stroke(Color(red: 151/255, green: 151/255, blue: 151/255), lineWidth: 0.5)
|
||||
.padding(1)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !localizedTags.isEmpty {
|
||||
TagsView(tags: localizedTags)
|
||||
}
|
||||
|
||||
HStack {
|
||||
|
@ -51,7 +43,7 @@ struct LocalModelRowView: View {
|
|||
.foregroundColor(.gray)
|
||||
.frame(width: 20, height: 20)
|
||||
|
||||
Text(model.formattedSize)
|
||||
Text(formattedSize)
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
|
|
|
@ -16,39 +16,32 @@ struct MainTabView: View {
|
|||
@State private var showWebView = false
|
||||
@State private var webViewURL: URL?
|
||||
@State private var navigateToSettings = false
|
||||
@StateObject private var modelListViewModel = TBModelListViewModel()
|
||||
@StateObject private var localModelListViewModel = ModelListViewModel()
|
||||
@StateObject private var modelListViewModel = ModelListViewModel()
|
||||
@State private var selectedTab: Int = 0
|
||||
@State private var titles = ["本地模型", "模型市场", "TB模型", "Benchmark"]
|
||||
@State private var titles = ["本地模型", "模型市场", "Benchmark"]
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
NavigationView {
|
||||
TabView(selection: $selectedTab) {
|
||||
LocalModelListView(viewModel: localModelListViewModel)
|
||||
LocalModelListView(viewModel: modelListViewModel)
|
||||
.tabItem {
|
||||
Image(systemName: "house.fill")
|
||||
Text("本地模型")
|
||||
}
|
||||
.tag(0)
|
||||
ModelListView(viewModel: localModelListViewModel)
|
||||
ModelListView(viewModel: modelListViewModel)
|
||||
.tabItem {
|
||||
Image(systemName: "cart.fill")
|
||||
Image(systemName: "doc.text.fill")
|
||||
Text("模型市场")
|
||||
}
|
||||
.tag(1)
|
||||
TBModelListView(viewModel: modelListViewModel)
|
||||
.tabItem {
|
||||
Image(systemName: "doc.text.fill")
|
||||
Text("TB模型")
|
||||
}
|
||||
.tag(2)
|
||||
BenchmarkView()
|
||||
.tabItem {
|
||||
Image(systemName: "clock.fill")
|
||||
Text("Benchmark")
|
||||
}
|
||||
.tag(3)
|
||||
.tag(2)
|
||||
}
|
||||
.background(
|
||||
ZStack {
|
||||
|
@ -130,19 +123,13 @@ struct MainTabView: View {
|
|||
|
||||
@ViewBuilder
|
||||
private var chatDestination: some View {
|
||||
if let model = localModelListViewModel.selectedModel {
|
||||
if let model = modelListViewModel.selectedModel {
|
||||
LLMChatView(modelInfo: model)
|
||||
.navigationBarHidden(false)
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar(.hidden, for: .tabBar) // Hide tab bar in chat
|
||||
} else if let history = selectedHistory {
|
||||
let modelInfo = ModelInfo(
|
||||
modelId: history.modelId,
|
||||
createdAt: "",
|
||||
downloads: 0,
|
||||
tags: [],
|
||||
isDownloaded: true
|
||||
)
|
||||
let modelInfo = ModelInfo(modelId: history.modelId, isDownloaded: true)
|
||||
LLMChatView(modelInfo: modelInfo, history: history)
|
||||
.navigationBarHidden(false)
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
|
@ -155,17 +142,17 @@ struct MainTabView: View {
|
|||
private var chatIsActiveBinding: Binding<Bool> {
|
||||
Binding<Bool>(
|
||||
get: {
|
||||
return localModelListViewModel.selectedModel != nil || selectedHistory != nil
|
||||
return modelListViewModel.selectedModel != nil || selectedHistory != nil
|
||||
},
|
||||
set: { isActive in
|
||||
if !isActive {
|
||||
// Record usage when returning from chat
|
||||
if let model = localModelListViewModel.selectedModel {
|
||||
localModelListViewModel.recordModelUsage(modelName: model.name)
|
||||
if let model = modelListViewModel.selectedModel {
|
||||
modelListViewModel.recordModelUsage(modelName: model.modelName)
|
||||
}
|
||||
|
||||
// Clear selections
|
||||
localModelListViewModel.selectedModel = nil
|
||||
modelListViewModel.selectedModel = nil
|
||||
selectedHistory = nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,83 +1,108 @@
|
|||
//
|
||||
// ModelClient.swift
|
||||
// TBModelInfo.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
struct ModelInfo: Codable {
|
||||
let modelId: String
|
||||
let createdAt: String
|
||||
let downloads: Int
|
||||
// MARK: - Properties
|
||||
let modelName: String
|
||||
let tags: [String]
|
||||
let categories: [String]?
|
||||
let size_gb: Double?
|
||||
let vendor: String?
|
||||
let sources: [String: String]?
|
||||
let tagTranslations: [String: [String]]?
|
||||
|
||||
var name: String {
|
||||
modelId.removingTaobaoPrefix()
|
||||
}
|
||||
|
||||
// Runtime properties
|
||||
var isDownloaded: Bool = false
|
||||
var lastUsedAt: Date?
|
||||
|
||||
var cachedSize: Int64? = nil
|
||||
|
||||
var localPath: String {
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
|
||||
// MARK: - Initialization
|
||||
|
||||
init(modelName: String = "",
|
||||
tags: [String] = [],
|
||||
categories: [String]? = nil,
|
||||
size_gb: Double? = nil,
|
||||
vendor: String? = nil,
|
||||
sources: [String: String]? = nil,
|
||||
tagTranslations: [String: [String]]? = nil,
|
||||
isDownloaded: Bool = false,
|
||||
lastUsedAt: Date? = nil,
|
||||
cachedSize: Int64? = nil) {
|
||||
|
||||
self.modelName = modelName
|
||||
self.tags = tags
|
||||
self.categories = categories
|
||||
self.size_gb = size_gb
|
||||
self.vendor = vendor
|
||||
self.sources = sources
|
||||
self.tagTranslations = tagTranslations
|
||||
self.isDownloaded = isDownloaded
|
||||
self.lastUsedAt = lastUsedAt
|
||||
self.cachedSize = cachedSize
|
||||
}
|
||||
|
||||
init(modelId: String, isDownloaded: Bool = true) {
|
||||
let modelName = modelId.components(separatedBy: "/").last ?? modelId
|
||||
|
||||
self.init(
|
||||
modelName: modelName,
|
||||
tags: [],
|
||||
sources: ["huggingface": modelId],
|
||||
isDownloaded: isDownloaded
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Model Identity & Localization
|
||||
|
||||
var id: String {
|
||||
guard let sources = sources else {
|
||||
return "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
|
||||
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
var localizedTags: [String] {
|
||||
let currentLanguage = LanguageManager.shared.currentLanguage
|
||||
let isChineseLanguage = currentLanguage == "简体中文"
|
||||
|
||||
if isChineseLanguage, let translations = tagTranslations {
|
||||
let languageCode = "zh-Hans"
|
||||
return translations[languageCode] ?? tags
|
||||
} else {
|
||||
return tags
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - File System & Path Management
|
||||
|
||||
var localPath: String {
|
||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
|
||||
}
|
||||
|
||||
// MARK: - Size Calculation & Formatting
|
||||
|
||||
var formattedSize: String {
|
||||
if isDownloaded {
|
||||
return formatLocalSize()
|
||||
} else if let cached = cachedSize {
|
||||
if let cached = cachedSize {
|
||||
return formatBytes(cached)
|
||||
} else if isDownloaded {
|
||||
return formatLocalSize()
|
||||
} else if let sizeGb = size_gb {
|
||||
return String(format: "%.1f GB", sizeGb)
|
||||
} else {
|
||||
return "计算中..."
|
||||
return "Calculating..."
|
||||
}
|
||||
}
|
||||
|
||||
func fetchRemoteSize() async -> Int64? {
|
||||
let modelScopeId = modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
|
||||
|
||||
do {
|
||||
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
||||
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
||||
return totalSize
|
||||
} catch {
|
||||
print("Error fetching remote size for \(modelId): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func formatLocalSize() -> String {
|
||||
let path = localPath
|
||||
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
|
||||
|
||||
do {
|
||||
let totalSize = try calculateDirectorySize(at: path)
|
||||
return formatBytes(totalSize)
|
||||
} catch {
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
||||
let fileManager = FileManager.default
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
let enumerator = fileManager.enumerator(atPath: path)
|
||||
while let fileName = enumerator?.nextObject() as? String {
|
||||
let filePath = (path as NSString).appendingPathComponent(fileName)
|
||||
let attributes = try fileManager.attributesOfItem(atPath: filePath)
|
||||
if let fileSize = attributes[.size] as? Int64 {
|
||||
totalSize += fileSize
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useGB]
|
||||
|
@ -85,7 +110,102 @@ struct ModelInfo: Codable {
|
|||
return formatter.string(fromByteCount: bytes)
|
||||
}
|
||||
|
||||
// MARK: - 云端文件大小计算方法
|
||||
// MARK: - Local Size Calculation
|
||||
|
||||
private func formatLocalSize() -> String {
|
||||
let path = localPath
|
||||
guard FileManager.default.fileExists(atPath: path) else { return "Unknown" }
|
||||
|
||||
do {
|
||||
let totalSize = try calculateDirectorySize(at: path)
|
||||
return formatBytes(totalSize)
|
||||
} catch {
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
||||
let fileManager = FileManager.default
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
print("Calculating directory size for path: \(path)")
|
||||
|
||||
let directoryURL = URL(fileURLWithPath: path)
|
||||
|
||||
guard fileManager.fileExists(atPath: path) else {
|
||||
print("Path does not exist: \(path)")
|
||||
return 0
|
||||
}
|
||||
|
||||
let resourceKeys: [URLResourceKey] = [.isRegularFileKey, .totalFileAllocatedSizeKey, .fileSizeKey, .nameKey]
|
||||
let enumerator = fileManager.enumerator(
|
||||
at: directoryURL,
|
||||
includingPropertiesForKeys: resourceKeys,
|
||||
options: [.skipsHiddenFiles, .skipsPackageDescendants],
|
||||
errorHandler: { (url, error) -> Bool in
|
||||
print("Error accessing \(url): \(error)")
|
||||
return true
|
||||
}
|
||||
)
|
||||
|
||||
guard let fileEnumerator = enumerator else {
|
||||
throw NSError(domain: "FileEnumerationError", code: -1,
|
||||
userInfo: [NSLocalizedDescriptionKey: "Failed to create file enumerator"])
|
||||
}
|
||||
|
||||
var fileCount = 0
|
||||
for case let fileURL as URL in fileEnumerator {
|
||||
do {
|
||||
let resourceValues = try fileURL.resourceValues(forKeys: Set(resourceKeys))
|
||||
|
||||
guard let isRegularFile = resourceValues.isRegularFile, isRegularFile else { continue }
|
||||
|
||||
let fileName = resourceValues.name ?? "Unknown"
|
||||
fileCount += 1
|
||||
|
||||
// Use actual disk allocated size, fallback to logical size if not available
|
||||
if let actualSize = resourceValues.totalFileAllocatedSize {
|
||||
totalSize += Int64(actualSize)
|
||||
|
||||
if fileCount <= 10 {
|
||||
let actualSizeGB = Double(actualSize) / (1024 * 1024 * 1024)
|
||||
let logicalSizeGB = Double(resourceValues.fileSize ?? 0) / (1024 * 1024 * 1024)
|
||||
print("File \(fileCount): \(fileName) - Logical: \(String(format: "%.3f", logicalSizeGB)) GB, Actual: \(String(format: "%.3f", actualSizeGB)) GB")
|
||||
}
|
||||
} else if let logicalSize = resourceValues.fileSize {
|
||||
totalSize += Int64(logicalSize)
|
||||
|
||||
if fileCount <= 10 {
|
||||
let logicalSizeGB = Double(logicalSize) / (1024 * 1024 * 1024)
|
||||
print("File \(fileCount): \(fileName) - Size: \(String(format: "%.3f", logicalSizeGB)) GB (fallback)")
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("Error getting resource values for \(fileURL): \(error)")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
let totalSizeGB = Double(totalSize) / (1024 * 1024 * 1024)
|
||||
print("Total files: \(fileCount), Total actual disk usage: \(String(format: "%.2f", totalSizeGB)) GB")
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
// MARK: - Remote Size Calculation
|
||||
|
||||
func fetchRemoteSize() async -> Int64? {
|
||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||
|
||||
do {
|
||||
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
||||
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
||||
return totalSize
|
||||
} catch {
|
||||
print("Error fetching remote size for \(id): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
||||
let url = try buildURL(
|
||||
|
@ -119,6 +239,8 @@ struct ModelInfo: Codable {
|
|||
return totalSize
|
||||
}
|
||||
|
||||
// MARK: - Network Utilities
|
||||
|
||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
||||
var components = URLComponents()
|
||||
components.scheme = "https"
|
||||
|
@ -139,21 +261,9 @@ struct ModelInfo: Codable {
|
|||
}
|
||||
}
|
||||
|
||||
// MARK: - Codable
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case modelId
|
||||
case tags
|
||||
case downloads
|
||||
case createdAt
|
||||
case cachedSize
|
||||
}
|
||||
}
|
||||
|
||||
struct RepoInfo: Codable {
|
||||
let modelId: String
|
||||
let sha: String
|
||||
let siblings: [Sibling]
|
||||
|
||||
struct Sibling: Codable {
|
||||
let rfilename: String
|
||||
case modelName, tags, categories, size_gb, vendor, sources, tagTranslations, cachedSize
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// ModelListViewModel.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
@ -10,74 +10,78 @@ import SwiftUI
|
|||
|
||||
@MainActor
|
||||
class ModelListViewModel: ObservableObject {
|
||||
// MARK: - Published Properties
|
||||
@Published var models: [ModelInfo] = []
|
||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||
@Published private(set) var currentlyDownloading: String?
|
||||
@Published var searchText = ""
|
||||
@Published var quickFilterTags: [String] = []
|
||||
@Published var selectedModel: ModelInfo?
|
||||
@Published var showError = false
|
||||
@Published var errorMessage = ""
|
||||
@Published var searchText = ""
|
||||
|
||||
@Published var selectedModel: ModelInfo?
|
||||
// Download state
|
||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||
@Published private(set) var currentlyDownloading: String?
|
||||
|
||||
// MARK: - Private Properties
|
||||
private let modelClient = ModelClient()
|
||||
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
||||
|
||||
// MARK: - Model Data Access
|
||||
|
||||
public var pinnedModelIds: [String] {
|
||||
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
||||
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
||||
}
|
||||
|
||||
var filteredModels: [ModelInfo] {
|
||||
var allTags: [String] {
|
||||
Array(Set(models.flatMap { $0.tags }))
|
||||
}
|
||||
|
||||
let filteredModels = searchText.isEmpty ? models : models.filter { model in
|
||||
model.modelId.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.tags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
||||
var allCategories: [String] {
|
||||
Array(Set(models.compactMap { $0.categories }.flatMap { $0 }))
|
||||
}
|
||||
|
||||
var allVendors: [String] {
|
||||
Array(Set(models.compactMap { $0.vendor }))
|
||||
}
|
||||
|
||||
var filteredModels: [ModelInfo] {
|
||||
let filtered = searchText.isEmpty ? models : models.filter { model in
|
||||
model.id.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
||||
}
|
||||
|
||||
let downloadedModels = filteredModels.filter { $0.isDownloaded }
|
||||
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
|
||||
let downloaded = filtered.filter { $0.isDownloaded }
|
||||
let notDownloaded = filtered.filter { !$0.isDownloaded }
|
||||
|
||||
return downloadedModels + notDownloadedModels
|
||||
return downloaded + notDownloaded
|
||||
}
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
init() {
|
||||
Task {
|
||||
await fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Data Management
|
||||
|
||||
func fetchModels() async {
|
||||
do {
|
||||
var fetchedModels = try await modelClient.getModelList()
|
||||
let info = try await modelClient.getModelInfo()
|
||||
|
||||
let hasDiffusionModels = fetchedModels.contains {
|
||||
$0.name.lowercased().contains("diffusion")
|
||||
}
|
||||
self.quickFilterTags = info.quickFilterTags ?? []
|
||||
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
|
||||
|
||||
if hasDiffusionModels {
|
||||
fetchedModels = fetchedModels.filter { model in
|
||||
let name = model.name.lowercased()
|
||||
let tags = model.tags.map { $0.lowercased() }
|
||||
var fetchedModels = info.models
|
||||
|
||||
// only show gpu diffusion
|
||||
if name.contains("diffusion") {
|
||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..<fetchedModels.count {
|
||||
let model = fetchedModels[i]
|
||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.name)
|
||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.name)
|
||||
}
|
||||
|
||||
// Sort models
|
||||
filterDiffusionModels(fetchedModels: &fetchedModels)
|
||||
sortModels(fetchedModels: &fetchedModels)
|
||||
self.models = fetchedModels
|
||||
|
||||
// 异步获取未下载模型的大小信息
|
||||
// Asynchronously fetch size info for undownloaded models
|
||||
Task {
|
||||
await fetchModelSizes(for: fetchedModels)
|
||||
}
|
||||
|
@ -91,12 +95,12 @@ class ModelListViewModel: ObservableObject {
|
|||
private func fetchModelSizes(for models: [ModelInfo]) async {
|
||||
await withTaskGroup(of: Void.self) { group in
|
||||
for (_, model) in models.enumerated() {
|
||||
if !model.isDownloaded && model.cachedSize == nil {
|
||||
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
|
||||
group.addTask {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
// 查找当前模型在实际数组中的索引
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
// Find current model index in actual array
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[modelIndex].cachedSize = size
|
||||
}
|
||||
}
|
||||
|
@ -107,11 +111,29 @@ class ModelListViewModel: ObservableObject {
|
|||
}
|
||||
}
|
||||
|
||||
func recordModelUsage(modelName: String) {
|
||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
||||
if let index = models.firstIndex(where: { $0.name == modelName }) {
|
||||
models[index].lastUsedAt = Date()
|
||||
sortModels(fetchedModels: &models)
|
||||
private func filterDiffusionModels(fetchedModels: inout [ModelInfo]) {
|
||||
let hasDiffusionModels = fetchedModels.contains {
|
||||
$0.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
if hasDiffusionModels {
|
||||
fetchedModels = fetchedModels.filter { model in
|
||||
let name = model.modelName.lowercased()
|
||||
let tags = model.tags.map { $0.lowercased() }
|
||||
|
||||
// Only show GPU diffusion models
|
||||
if name.contains("diffusion") {
|
||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..<fetchedModels.count {
|
||||
let model = fetchedModels[i]
|
||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
|
||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,24 +141,34 @@ class ModelListViewModel: ObservableObject {
|
|||
let pinned = pinnedModelIds
|
||||
|
||||
fetchedModels.sort { (model1, model2) -> Bool in
|
||||
let isPinned1 = pinned.contains(model1.modelId)
|
||||
let isPinned2 = pinned.contains(model2.modelId)
|
||||
let isPinned1 = pinned.contains(model1.id)
|
||||
let isPinned2 = pinned.contains(model2.id)
|
||||
let isDownloading1 = currentlyDownloading == model1.id
|
||||
let isDownloading2 = currentlyDownloading == model2.id
|
||||
|
||||
// 1. Currently downloading models have highest priority
|
||||
if isDownloading1 != isDownloading2 {
|
||||
return isDownloading1
|
||||
}
|
||||
|
||||
// 2. Pinned models have second priority
|
||||
if isPinned1 != isPinned2 {
|
||||
return isPinned1
|
||||
}
|
||||
|
||||
// 3. If both are pinned, sort by pin time
|
||||
if isPinned1 && isPinned2 {
|
||||
let index1 = pinned.firstIndex(of: model1.modelId)!
|
||||
let index2 = pinned.firstIndex(of: model2.modelId)!
|
||||
let index1 = pinned.firstIndex(of: model1.id)!
|
||||
let index2 = pinned.firstIndex(of: model2.id)!
|
||||
return index1 > index2 // Pinned later comes first
|
||||
}
|
||||
|
||||
// Non-pinned models
|
||||
// 4. Non-pinned models sorted by download status
|
||||
if model1.isDownloaded != model2.isDownloaded {
|
||||
return model1.isDownloaded
|
||||
}
|
||||
|
||||
// 5. If both downloaded, sort by last used time
|
||||
if model1.isDownloaded {
|
||||
let date1 = model1.lastUsedAt ?? .distantPast
|
||||
let date2 = model2.lastUsedAt ?? .distantPast
|
||||
|
@ -145,10 +177,10 @@ class ModelListViewModel: ObservableObject {
|
|||
|
||||
return false // Keep original order for not-downloaded
|
||||
}
|
||||
|
||||
models = fetchedModels
|
||||
}
|
||||
|
||||
// MARK: - Model Selection & Usage
|
||||
|
||||
func selectModel(_ model: ModelInfo) {
|
||||
if model.isDownloaded {
|
||||
selectedModel = model
|
||||
|
@ -159,26 +191,35 @@ class ModelListViewModel: ObservableObject {
|
|||
}
|
||||
}
|
||||
|
||||
func recordModelUsage(modelName: String) {
|
||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
||||
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
|
||||
models[index].lastUsedAt = Date()
|
||||
sortModels(fetchedModels: &models)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Download Management
|
||||
|
||||
func downloadModel(_ model: ModelInfo) async {
|
||||
guard currentlyDownloading == nil else { return }
|
||||
|
||||
currentlyDownloading = model.modelId
|
||||
downloadProgress[model.modelId] = 0
|
||||
currentlyDownloading = model.id
|
||||
downloadProgress[model.id] = 0
|
||||
|
||||
do {
|
||||
try await modelClient.downloadModel(model: model) { progress in
|
||||
Task { @MainActor in
|
||||
self.downloadProgress[model.modelId] = progress
|
||||
self.downloadProgress[model.id] = progress
|
||||
}
|
||||
}
|
||||
|
||||
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||
models[index].isDownloaded = true
|
||||
ModelStorageManager.shared.markModelAsDownloaded(model.name)
|
||||
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
|
||||
}
|
||||
|
||||
} catch {
|
||||
|
||||
if case ModelScopeError.downloadCancelled = error {
|
||||
print("Download was cancelled")
|
||||
} else {
|
||||
|
@ -188,7 +229,7 @@ class ModelListViewModel: ObservableObject {
|
|||
}
|
||||
|
||||
currentlyDownloading = nil
|
||||
downloadProgress.removeValue(forKey: model.modelId)
|
||||
downloadProgress.removeValue(forKey: model.id)
|
||||
}
|
||||
|
||||
func cancelDownload() async {
|
||||
|
@ -202,23 +243,27 @@ class ModelListViewModel: ObservableObject {
|
|||
}
|
||||
}
|
||||
|
||||
// MARK: - Pin Management
|
||||
|
||||
func pinModel(_ model: ModelInfo) {
|
||||
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return }
|
||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||
let pinned = models.remove(at: index)
|
||||
models.insert(pinned, at: 0)
|
||||
var ids = pinnedModelIds.filter { $0 != model.modelId }
|
||||
ids.append(model.modelId)
|
||||
var ids = pinnedModelIds.filter { $0 != model.id }
|
||||
ids.append(model.id)
|
||||
pinnedModelIds = ids
|
||||
}
|
||||
|
||||
func unpinModel(_ model: ModelInfo) {
|
||||
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return }
|
||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||
let unpinned = models.remove(at: index)
|
||||
let insertIndex = models.count // 取消置顶后放到未置顶最后
|
||||
let insertIndex = models.count // Insert at end after unpinning
|
||||
models.insert(unpinned, at: insertIndex)
|
||||
pinnedModelIds = pinnedModelIds.filter { $0 != model.modelId }
|
||||
pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
|
||||
}
|
||||
|
||||
// MARK: - Model Deletion
|
||||
|
||||
func deleteModel(_ model: ModelInfo) async {
|
||||
do {
|
||||
let fileManager = FileManager.default
|
||||
|
@ -240,11 +285,11 @@ class ModelListViewModel: ObservableObject {
|
|||
}
|
||||
|
||||
await MainActor.run {
|
||||
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||
models[index].isDownloaded = false
|
||||
ModelStorageManager.shared.clearDownloadStatus(for: model.name)
|
||||
ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
|
||||
}
|
||||
if selectedModel?.modelId == model.modelId {
|
||||
if selectedModel?.id == model.id {
|
||||
selectedModel = nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
//
|
||||
// TBDataResponse.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/9.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
|
||||
struct TBDataResponse: Codable {
|
||||
let tagTranslations: [String: String]
|
||||
let quickFilterTags: [String]?
|
||||
let models: [ModelInfo]
|
||||
let metadata: Metadata?
|
||||
|
||||
struct Metadata: Codable {
|
||||
let version: String
|
||||
let lastUpdated: String
|
||||
let schemaVersion: String
|
||||
let totalModels: Int
|
||||
let supportedPlatforms: [String]
|
||||
let minAppVersion: String
|
||||
}
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
//
|
||||
// TBModelInfo.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
struct TBModelInfo: Codable {
|
||||
let modelName: String
|
||||
let tags: [String]
|
||||
let categories: [String]?
|
||||
let size_gb: Double?
|
||||
let vendor: String?
|
||||
let sources: [String: String]?
|
||||
let tagTranslations: [String: [String]]?
|
||||
|
||||
// 运行时属性
|
||||
var isDownloaded: Bool = false
|
||||
var lastUsedAt: Date?
|
||||
var cachedSize: Int64? = nil
|
||||
|
||||
// 统一的ID属性,根据当前选择的源获取对应的modelId
|
||||
var id: String {
|
||||
guard let sources = sources else {
|
||||
return "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
|
||||
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
|
||||
}
|
||||
|
||||
// 本地化的标签 - 使用模型自己的tagTranslations
|
||||
var localizedTags: [String] {
|
||||
let currentLanguage = LanguageManager.shared.currentLanguage
|
||||
let isChineseLanguage = currentLanguage == "简体中文"
|
||||
|
||||
if isChineseLanguage, let translations = tagTranslations {
|
||||
let languageCode = "zh-Hans"
|
||||
return translations[languageCode] ?? tags
|
||||
} else {
|
||||
return tags
|
||||
}
|
||||
}
|
||||
|
||||
var localPath: String {
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: id)).path
|
||||
}
|
||||
|
||||
var formattedSize: String {
|
||||
if isDownloaded {
|
||||
return formatLocalSize()
|
||||
} else if let cached = cachedSize {
|
||||
return formatBytes(cached)
|
||||
} else if let sizeGb = size_gb {
|
||||
return String(format: "%.1f GB", sizeGb)
|
||||
} else {
|
||||
return "计算中..."
|
||||
}
|
||||
}
|
||||
|
||||
func fetchRemoteSize() async -> Int64? {
|
||||
// TODO: now only support modelScope, support huggingFace later
|
||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||
|
||||
do {
|
||||
let files = try await fetchFileList(repoPath: id, root: "", revision: "")
|
||||
let totalSize = try await calculateTotalSize(files: files, repoPath: id)
|
||||
return totalSize
|
||||
} catch {
|
||||
print("Error fetching remote size for \(id): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func formatLocalSize() -> String {
|
||||
let path = localPath
|
||||
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
|
||||
|
||||
do {
|
||||
let totalSize = try calculateDirectorySize(at: path)
|
||||
return formatBytes(totalSize)
|
||||
} catch {
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
||||
let fileManager = FileManager.default
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
let enumerator = fileManager.enumerator(atPath: path)
|
||||
while let fileName = enumerator?.nextObject() as? String {
|
||||
let filePath = (path as NSString).appendingPathComponent(fileName)
|
||||
let attributes = try fileManager.attributesOfItem(atPath: filePath)
|
||||
if let fileSize = attributes[.size] as? Int64 {
|
||||
totalSize += fileSize
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useGB]
|
||||
formatter.countStyle = .file
|
||||
return formatter.string(fromByteCount: bytes)
|
||||
}
|
||||
|
||||
// MARK: - 云端文件大小计算方法
|
||||
|
||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
||||
let url = try buildURL(
|
||||
repoPath: repoPath,
|
||||
path: "/repo/files",
|
||||
queryItems: [
|
||||
URLQueryItem(name: "Root", value: root),
|
||||
URLQueryItem(name: "Revision", value: revision)
|
||||
]
|
||||
)
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(from: url)
|
||||
try validateResponse(response)
|
||||
|
||||
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
|
||||
return modelResponse.data.files
|
||||
}
|
||||
|
||||
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
for file in files {
|
||||
if file.type == "tree" {
|
||||
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
|
||||
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
|
||||
} else if file.type == "blob" {
|
||||
totalSize += Int64(file.size)
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
||||
var components = URLComponents()
|
||||
components.scheme = "https"
|
||||
components.host = "modelscope.cn"
|
||||
components.path = "/api/v1/models/\(repoPath)\(path)"
|
||||
components.queryItems = queryItems
|
||||
|
||||
guard let url = components.url else {
|
||||
throw ModelScopeError.invalidURL
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
private func validateResponse(_ response: URLResponse) throws {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200...299).contains(httpResponse.statusCode) else {
|
||||
throw ModelScopeError.invalidResponse
|
||||
}
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case modelName
|
||||
case tags
|
||||
case categories
|
||||
case size_gb
|
||||
case vendor
|
||||
case sources
|
||||
case tagTranslations
|
||||
case cachedSize
|
||||
}
|
||||
}
|
||||
|
||||
// 更新TagTranslationManager以支持单个标签翻译
|
||||
//extension TagTranslationManager {
|
||||
// func getTranslation(for tag: String) -> String? {
|
||||
// return globalTagTranslations[tag]
|
||||
// }
|
||||
//}
|
|
@ -1,296 +0,0 @@
|
|||
//
|
||||
// TBModelListViewModel.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
|
||||
@MainActor
|
||||
class TBModelListViewModel: ObservableObject {
|
||||
@Published var models: [TBModelInfo] = []
|
||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||
@Published private(set) var currentlyDownloading: String?
|
||||
@Published var showError = false
|
||||
@Published var errorMessage = ""
|
||||
@Published var searchText = ""
|
||||
@Published var quickFilterTags: [String] = []
|
||||
|
||||
@Published var selectedModel: TBModelInfo?
|
||||
|
||||
private let modelClient = TBModelClient()
|
||||
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
||||
|
||||
public var pinnedModelIds: [String] {
|
||||
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
||||
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
||||
}
|
||||
|
||||
// 获取所有可用的标签
|
||||
var allTags: [String] {
|
||||
let allTags = Set(models.flatMap { $0.tags })
|
||||
return Array(allTags)
|
||||
}
|
||||
|
||||
// 获取所有可用的分类
|
||||
var allCategories: [String] {
|
||||
let allCategories = Set(models.compactMap { $0.categories }.flatMap { $0 })
|
||||
return Array(allCategories)
|
||||
}
|
||||
|
||||
// 获取所有可用的厂商
|
||||
var allVendors: [String] {
|
||||
let allVendors = Set(models.compactMap { $0.vendor })
|
||||
return Array(allVendors)
|
||||
}
|
||||
|
||||
var filteredModels: [TBModelInfo] {
|
||||
let filteredModels = searchText.isEmpty ? models : models.filter { model in
|
||||
model.id.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
|
||||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
||||
}
|
||||
|
||||
let downloadedModels = filteredModels.filter { $0.isDownloaded }
|
||||
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
|
||||
|
||||
return downloadedModels + notDownloadedModels
|
||||
}
|
||||
|
||||
init() {
|
||||
Task {
|
||||
await fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
func fetchModels() async {
|
||||
do {
|
||||
let info = try await modelClient.getModelInfo()
|
||||
|
||||
self.quickFilterTags = info.quickFilterTags ?? []
|
||||
|
||||
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
|
||||
|
||||
var fetchedModels = info.models
|
||||
|
||||
self.filteDiffusionModel(fetchedModels: &fetchedModels)
|
||||
self.sortModels(fetchedModels: &fetchedModels)
|
||||
self.models = fetchedModels
|
||||
|
||||
// 异步获取未下载模型的大小信息
|
||||
Task {
|
||||
await fetchModelSizes(for: fetchedModels)
|
||||
}
|
||||
|
||||
} catch {
|
||||
showError = true
|
||||
errorMessage = "Error: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchModelSizes(for models: [TBModelInfo]) async {
|
||||
await withTaskGroup(of: Void.self) { group in
|
||||
for (_, model) in models.enumerated() {
|
||||
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
|
||||
group.addTask {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
// 查找当前模型在实际数组中的索引
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||
self.models[modelIndex].cachedSize = size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func recordModelUsage(modelName: String) {
|
||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
||||
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
|
||||
models[index].lastUsedAt = Date()
|
||||
sortModels(fetchedModels: &models)
|
||||
}
|
||||
}
|
||||
|
||||
private func filteDiffusionModel(fetchedModels: inout [TBModelInfo]) {
|
||||
let hasDiffusionModels = fetchedModels.contains {
|
||||
$0.modelName.lowercased().contains("diffusion")
|
||||
}
|
||||
|
||||
if hasDiffusionModels {
|
||||
fetchedModels = fetchedModels.filter { model in
|
||||
let name = model.modelName.lowercased()
|
||||
let tags = model.tags.map { $0.lowercased() }
|
||||
|
||||
// only show gpu diffusion
|
||||
if name.contains("diffusion") {
|
||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..<fetchedModels.count {
|
||||
let model = fetchedModels[i]
|
||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
|
||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
|
||||
}
|
||||
}
|
||||
|
||||
private func sortModels(fetchedModels: inout [TBModelInfo]) {
|
||||
let pinned = pinnedModelIds
|
||||
|
||||
fetchedModels.sort { (model1, model2) -> Bool in
|
||||
let isPinned1 = pinned.contains(model1.id)
|
||||
let isPinned2 = pinned.contains(model2.id)
|
||||
let isDownloading1 = currentlyDownloading == model1.id
|
||||
let isDownloading2 = currentlyDownloading == model2.id
|
||||
|
||||
// 1. 正在下载的模型优先级最高
|
||||
if isDownloading1 != isDownloading2 {
|
||||
return isDownloading1
|
||||
}
|
||||
|
||||
// 2. 置顶的模型次优先级
|
||||
if isPinned1 != isPinned2 {
|
||||
return isPinned1
|
||||
}
|
||||
|
||||
// 3. 如果都是置顶的,按置顶时间排序
|
||||
if isPinned1 && isPinned2 {
|
||||
let index1 = pinned.firstIndex(of: model1.id)!
|
||||
let index2 = pinned.firstIndex(of: model2.id)!
|
||||
return index1 > index2 // Pinned later comes first
|
||||
}
|
||||
|
||||
// 4. 非置顶模型按下载状态排序
|
||||
if model1.isDownloaded != model2.isDownloaded {
|
||||
return model1.isDownloaded
|
||||
}
|
||||
|
||||
// 5. 如果都已下载,按最后使用时间排序
|
||||
if model1.isDownloaded {
|
||||
let date1 = model1.lastUsedAt ?? .distantPast
|
||||
let date2 = model2.lastUsedAt ?? .distantPast
|
||||
return date1 > date2
|
||||
}
|
||||
|
||||
return false // Keep original order for not-downloaded
|
||||
}
|
||||
}
|
||||
|
||||
func selectModel(_ model: TBModelInfo) {
|
||||
if model.isDownloaded {
|
||||
selectedModel = model
|
||||
} else {
|
||||
Task {
|
||||
await downloadModel(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func downloadModel(_ model: TBModelInfo) async {
|
||||
guard currentlyDownloading == nil else { return }
|
||||
|
||||
currentlyDownloading = model.id
|
||||
downloadProgress[model.id] = 0
|
||||
|
||||
do {
|
||||
try await modelClient.downloadModel(model: model) { progress in
|
||||
Task { @MainActor in
|
||||
self.downloadProgress[model.id] = progress
|
||||
}
|
||||
}
|
||||
|
||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||
models[index].isDownloaded = true
|
||||
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
|
||||
}
|
||||
|
||||
} catch {
|
||||
|
||||
if case ModelScopeError.downloadCancelled = error {
|
||||
print("Download was cancelled")
|
||||
} else {
|
||||
showError = true
|
||||
errorMessage = "Failed to download model: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
|
||||
currentlyDownloading = nil
|
||||
downloadProgress.removeValue(forKey: model.id)
|
||||
}
|
||||
|
||||
func cancelDownload() async {
|
||||
if let modelId = currentlyDownloading {
|
||||
await modelClient.cancelDownload()
|
||||
|
||||
downloadProgress.removeValue(forKey: modelId)
|
||||
currentlyDownloading = nil
|
||||
|
||||
print("Download cancelled for model: \(modelId)")
|
||||
}
|
||||
}
|
||||
|
||||
func pinModel(_ model: TBModelInfo) {
|
||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||
let pinned = models.remove(at: index)
|
||||
models.insert(pinned, at: 0)
|
||||
var ids = pinnedModelIds.filter { $0 != model.id }
|
||||
ids.append(model.id)
|
||||
pinnedModelIds = ids
|
||||
}
|
||||
|
||||
func unpinModel(_ model: TBModelInfo) {
|
||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||
let unpinned = models.remove(at: index)
|
||||
let insertIndex = models.count // 取消置顶后放到未置顶最后
|
||||
models.insert(unpinned, at: insertIndex)
|
||||
pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
|
||||
}
|
||||
|
||||
func deleteModel(_ model: TBModelInfo) async {
|
||||
do {
|
||||
let fileManager = FileManager.default
|
||||
let modelPath = URL.init(filePath: model.localPath)
|
||||
|
||||
if let files = try? fileManager.contentsOfDirectory(
|
||||
at: modelPath,
|
||||
includingPropertiesForKeys: nil,
|
||||
options: [.skipsHiddenFiles]
|
||||
) {
|
||||
let storage = ModelDownloadStorage()
|
||||
for file in files {
|
||||
storage.clearFileStatus(at: file.path)
|
||||
}
|
||||
}
|
||||
|
||||
if fileManager.fileExists(atPath: modelPath.path) {
|
||||
try fileManager.removeItem(at: modelPath)
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||
models[index].isDownloaded = false
|
||||
ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
|
||||
}
|
||||
if selectedModel?.id == model.id {
|
||||
selectedModel = nil
|
||||
}
|
||||
}
|
||||
|
||||
} catch {
|
||||
print("Error deleting model: \(error)")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
|
||||
self.showError = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
// ModelClient.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
|
@ -26,17 +26,41 @@ class ModelClient {
|
|||
|
||||
init() {}
|
||||
|
||||
func getModelInfo() async throws -> TBDataResponse {
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
return mockResponse
|
||||
}
|
||||
|
||||
func getModelList() async throws -> [ModelInfo] {
|
||||
let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
||||
return try await performRequest(url: url, retries: maxRetries)
|
||||
// TODO: get json from network
|
||||
// let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
||||
// return try await performRequest(url: url, retries: maxRetries)
|
||||
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
|
||||
// 加载全局标签翻译
|
||||
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
|
||||
|
||||
return mockResponse.models
|
||||
}
|
||||
|
||||
func getRepoInfo(repoName: String, revision: String) async throws -> RepoInfo {
|
||||
let url = URL(string: "\(baseURLString)/api/models/\(repoName)")!
|
||||
return try await performRequest(url: url, retries: maxRetries)
|
||||
}
|
||||
|
||||
@MainActor
|
||||
/**
|
||||
* Downloads a model from the selected source with progress tracking
|
||||
*
|
||||
* @param model The ModelInfo object containing model details
|
||||
* @param progress Progress callback that receives download progress (0.0 to 1.0)
|
||||
* @throws Various network or file system errors
|
||||
*/
|
||||
func downloadModel(model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
switch ModelSourceManager.shared.selectedSource {
|
||||
|
@ -47,7 +71,9 @@ class ModelClient {
|
|||
}
|
||||
}
|
||||
|
||||
@MainActor
|
||||
/**
|
||||
* Cancels the current download operation
|
||||
*/
|
||||
func cancelDownload() async {
|
||||
if let manager = currentDownloadManager {
|
||||
await manager.cancelDownload()
|
||||
|
@ -55,10 +81,16 @@ class ModelClient {
|
|||
print("Download cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads model from ModelScope platform
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromModelScope(_ model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let ModelScopeId = model.modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
|
||||
let ModelScopeId = model.id
|
||||
let config = URLSessionConfiguration.default
|
||||
config.timeoutIntervalForRequest = 30
|
||||
config.timeoutIntervalForResource = 300
|
||||
|
@ -66,56 +98,68 @@ class ModelClient {
|
|||
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
||||
currentDownloadManager = manager
|
||||
|
||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.name) { fileProgress in
|
||||
progress(fileProgress)
|
||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
|
||||
Task { @MainActor in
|
||||
progress(fileProgress)
|
||||
}
|
||||
}
|
||||
|
||||
currentDownloadManager = nil
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads model from HuggingFace platform with optimized progress updates
|
||||
*
|
||||
* This method implements throttling to prevent UI stuttering by limiting
|
||||
* progress update frequency and filtering out minor progress changes.
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromHuggingFace(_ model: ModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let repo = Hub.Repo(id: model.modelId)
|
||||
let repo = Hub.Repo(id: model.id)
|
||||
let modelFiles = ["*.*"]
|
||||
let mirrorHubApi = HubApi(endpoint: baseURL)
|
||||
|
||||
// Progress throttling mechanism to prevent UI stuttering
|
||||
var lastUpdateTime = Date()
|
||||
var lastProgress: Double = 0.0
|
||||
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
|
||||
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
|
||||
|
||||
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
||||
progress(fileProgress.fractionCompleted)
|
||||
}
|
||||
}
|
||||
let currentProgress = fileProgress.fractionCompleted
|
||||
let currentTime = Date()
|
||||
|
||||
private func performRequest<T: Decodable>(url: URL, retries: Int = 3) async throws -> T {
|
||||
var lastError: Error?
|
||||
// Check if progress should be updated
|
||||
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
||||
let progressDiff = abs(currentProgress - lastProgress)
|
||||
|
||||
for attempt in 1...retries {
|
||||
do {
|
||||
var request = URLRequest(url: url)
|
||||
request.setValue("application/json", forHTTPHeaderField: "Accept")
|
||||
// Update progress if any of these conditions are met:
|
||||
// 1. Time interval exceeds threshold
|
||||
// 2. Progress change exceeds threshold
|
||||
// 3. Progress reaches 100% (download complete)
|
||||
// 4. Progress is 0% (download start)
|
||||
if timeDiff >= progressUpdateInterval ||
|
||||
progressDiff >= progressThreshold ||
|
||||
currentProgress >= 1.0 ||
|
||||
currentProgress == 0.0 {
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
lastUpdateTime = currentTime
|
||||
lastProgress = currentProgress
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw NetworkError.invalidResponse
|
||||
}
|
||||
|
||||
if httpResponse.statusCode == 200 {
|
||||
return try JSONDecoder().decode(T.self, from: data)
|
||||
}
|
||||
|
||||
throw NetworkError.invalidResponse
|
||||
|
||||
} catch {
|
||||
lastError = error
|
||||
if attempt < retries {
|
||||
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(attempt)) * 1_000_000_000))
|
||||
continue
|
||||
// Ensure progress updates are executed on the main thread
|
||||
Task { @MainActor in
|
||||
progress(currentProgress)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError ?? NetworkError.unknown
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
enum NetworkError: Error {
|
||||
case invalidResponse
|
||||
case invalidData
|
||||
|
|
|
@ -1,176 +0,0 @@
|
|||
//
|
||||
// TBModelClient.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import Hub
|
||||
import Foundation
|
||||
|
||||
class TBModelClient {
|
||||
private let baseMirrorURL = "https://hf-mirror.com"
|
||||
private let baseURL = "https://huggingface.co"
|
||||
private let maxRetries = 5
|
||||
|
||||
private var currentDownloadManager: ModelScopeDownloadManager?
|
||||
|
||||
private lazy var baseURLString: String = {
|
||||
switch ModelSourceManager.shared.selectedSource {
|
||||
case .huggingFace:
|
||||
return baseURL
|
||||
default:
|
||||
return baseMirrorURL
|
||||
}
|
||||
}()
|
||||
|
||||
init() {}
|
||||
|
||||
func getModelInfo() async throws -> TBDataResponse {
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
return mockResponse
|
||||
}
|
||||
|
||||
func getModelList() async throws -> [TBModelInfo] {
|
||||
// TODO: get json from network
|
||||
// let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
||||
// return try await performRequest(url: url, retries: maxRetries)
|
||||
|
||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||
throw NetworkError.invalidData
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||
|
||||
// 加载全局标签翻译
|
||||
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
|
||||
|
||||
return mockResponse.models
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a model from the selected source with progress tracking
|
||||
*
|
||||
* @param model The ModelInfo object containing model details
|
||||
* @param progress Progress callback that receives download progress (0.0 to 1.0)
|
||||
* @throws Various network or file system errors
|
||||
*/
|
||||
func downloadModel(model: TBModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
switch ModelSourceManager.shared.selectedSource {
|
||||
case .modelScope, .modeler:
|
||||
try await downloadFromModelScope(model, progress: progress)
|
||||
case .huggingFace:
|
||||
try await downloadFromHuggingFace(model, progress: progress)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancels the current download operation
|
||||
*/
|
||||
func cancelDownload() async {
|
||||
if let manager = currentDownloadManager {
|
||||
await manager.cancelDownload()
|
||||
currentDownloadManager = nil
|
||||
print("Download cancelled")
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Downloads model from ModelScope platform
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromModelScope(_ model: TBModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let ModelScopeId = model.id
|
||||
let config = URLSessionConfiguration.default
|
||||
config.timeoutIntervalForRequest = 30
|
||||
config.timeoutIntervalForResource = 300
|
||||
|
||||
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
||||
currentDownloadManager = manager
|
||||
|
||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
|
||||
Task { @MainActor in
|
||||
progress(fileProgress)
|
||||
}
|
||||
}
|
||||
|
||||
currentDownloadManager = nil
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads model from HuggingFace platform with optimized progress updates
|
||||
*
|
||||
* This method implements throttling to prevent UI stuttering by limiting
|
||||
* progress update frequency and filtering out minor progress changes.
|
||||
*
|
||||
* @param model The ModelInfo object to download
|
||||
* @param progress Progress callback for download updates
|
||||
* @throws Download or network related errors
|
||||
*/
|
||||
private func downloadFromHuggingFace(_ model: TBModelInfo,
|
||||
progress: @escaping (Double) -> Void) async throws {
|
||||
let repo = Hub.Repo(id: model.id)
|
||||
let modelFiles = ["*.*"]
|
||||
let mirrorHubApi = HubApi(endpoint: baseURL)
|
||||
|
||||
// Progress throttling mechanism to prevent UI stuttering
|
||||
var lastUpdateTime = Date()
|
||||
var lastProgress: Double = 0.0
|
||||
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
|
||||
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
|
||||
|
||||
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
||||
let currentProgress = fileProgress.fractionCompleted
|
||||
let currentTime = Date()
|
||||
|
||||
// Check if progress should be updated
|
||||
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
||||
let progressDiff = abs(currentProgress - lastProgress)
|
||||
|
||||
// Update progress if any of these conditions are met:
|
||||
// 1. Time interval exceeds threshold
|
||||
// 2. Progress change exceeds threshold
|
||||
// 3. Progress reaches 100% (download complete)
|
||||
// 4. Progress is 0% (download start)
|
||||
if timeDiff >= progressUpdateInterval ||
|
||||
progressDiff >= progressThreshold ||
|
||||
currentProgress >= 1.0 ||
|
||||
currentProgress == 0.0 {
|
||||
|
||||
lastUpdateTime = currentTime
|
||||
lastProgress = currentProgress
|
||||
|
||||
// Ensure progress updates are executed on the main thread
|
||||
Task { @MainActor in
|
||||
progress(currentProgress)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TBDataResponse: Codable {
|
||||
let tagTranslations: [String: String]
|
||||
let quickFilterTags: [String]?
|
||||
let models: [TBModelInfo]
|
||||
let metadata: MockMetadata?
|
||||
|
||||
struct MockMetadata: Codable {
|
||||
let version: String
|
||||
let lastUpdated: String
|
||||
let schemaVersion: String
|
||||
let totalModels: Int
|
||||
let supportedPlatforms: [String]
|
||||
let minAppVersion: String
|
||||
}
|
||||
}
|
|
@ -10,7 +10,7 @@ import SwiftUI
|
|||
// MARK: - 筛选菜单视图
|
||||
struct FilterMenuView: View {
|
||||
@Environment(\.dismiss) private var dismiss
|
||||
@StateObject private var viewModel = TBModelListViewModel()
|
||||
@StateObject private var viewModel = ModelListViewModel()
|
||||
@Binding var selectedTags: Set<String>
|
||||
@Binding var selectedCategories: Set<String>
|
||||
@Binding var selectedVendors: Set<String>
|
||||
|
|
|
@ -9,7 +9,7 @@ import SwiftUI
|
|||
|
||||
// MARK: - 工具栏视图
|
||||
struct ToolbarView: View {
|
||||
@ObservedObject var viewModel: TBModelListViewModel
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
@Binding var selectedSource: ModelSource
|
||||
@Binding var showSourceMenu: Bool
|
||||
@Binding var selectedTags: Set<String>
|
||||
|
|
|
@ -2,136 +2,146 @@
|
|||
// ModelListView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
|
||||
struct ModelListView: View {
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
|
||||
@State private var scrollOffset: CGFloat = 0
|
||||
@State private var showHelp = false
|
||||
@State private var showUserGuide = false
|
||||
|
||||
@State private var downloadSources: ModelSource?
|
||||
@State private var searchText = ""
|
||||
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
||||
|
||||
@State private var showOptions = false
|
||||
@State private var buttonFrame: CGRect = .zero
|
||||
@State private var showSourceMenu = false
|
||||
@State private var selectedTags: Set<String> = []
|
||||
@State private var selectedCategories: Set<String> = []
|
||||
@State private var selectedVendors: Set<String> = []
|
||||
@State private var showFilterMenu = false
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
VStack {
|
||||
HStack {
|
||||
Button {
|
||||
showOptions.toggle()
|
||||
} label: {
|
||||
HStack {
|
||||
Text("下载源:")
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
||||
Text(selectedSource.rawValue)
|
||||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
||||
Image(systemName: "chevron.down")
|
||||
.frame(width: 10, height: 10, alignment: .leading)
|
||||
.scaledToFit()
|
||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
||||
}
|
||||
.padding(.leading)
|
||||
}
|
||||
Spacer()
|
||||
ScrollView {
|
||||
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||
Section {
|
||||
modelListSection
|
||||
} header: {
|
||||
toolbarSection
|
||||
}
|
||||
.frame(maxWidth: .infinity, maxHeight: 20)
|
||||
.background(
|
||||
GeometryReader { geometry in
|
||||
Color.white.onAppear {
|
||||
buttonFrame = geometry.frame(in: .global)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
List {
|
||||
SearchBar(text: $viewModel.searchText)
|
||||
.listRowInsets(EdgeInsets())
|
||||
.listRowSeparator(.hidden)
|
||||
.padding(.horizontal)
|
||||
|
||||
ForEach(viewModel.filteredModels, id: \.modelId) { model in
|
||||
|
||||
ModelRowView(model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: viewModel.downloadProgress[model.modelId] ?? 0,
|
||||
isDownloading: viewModel.currentlyDownloading == model.modelId,
|
||||
isOtherDownloading: viewModel.currentlyDownloading != nil) {
|
||||
if model.isDownloaded {
|
||||
viewModel.selectModel(model)
|
||||
} else {
|
||||
Task {
|
||||
await viewModel.downloadModel(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
.listRowSeparator(.hidden)
|
||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear)
|
||||
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
||||
SwipeActionsView(model: model, viewModel: viewModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
.listStyle(.plain)
|
||||
.sheet(isPresented: $showHelp) {
|
||||
HelpView()
|
||||
}
|
||||
.refreshable {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK", role: .cancel) {}
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
.onAppear {
|
||||
checkFirstLaunch()
|
||||
}
|
||||
.alert(isPresented: $showUserGuide) {
|
||||
Alert(
|
||||
title: Text("User Guide"),
|
||||
message: Text("""
|
||||
This is a local large model application that requires certain performance from your device.
|
||||
It is recommended to choose different model sizes based on your device's memory.
|
||||
|
||||
The model recommendations for iPhone are as follows:
|
||||
- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).
|
||||
- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).
|
||||
- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).
|
||||
|
||||
Choosing a model that is too large may cause insufficient memory and crashes.
|
||||
"""),
|
||||
dismissButton: .default(Text("OK"))
|
||||
)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
.searchable(text: $searchText, prompt: "Search models...")
|
||||
.onChange(of: searchText) { _, newValue in
|
||||
viewModel.searchText = newValue
|
||||
}
|
||||
.refreshable {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK") { }
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if showOptions {
|
||||
CustomPopupMenu(isPresented: $showOptions,
|
||||
selectedSource: $selectedSource,
|
||||
anchorFrame: buttonFrame)
|
||||
// Extract model list section as independent view
|
||||
@ViewBuilder
|
||||
private var modelListSection: some View {
|
||||
LazyVStack(spacing: 8) {
|
||||
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
|
||||
modelRowView(model: model, index: index)
|
||||
|
||||
if index < filteredModels.count - 1 {
|
||||
Divider()
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
|
||||
// Extract toolbar section as independent view
|
||||
@ViewBuilder
|
||||
private var toolbarSection: some View {
|
||||
ToolbarView(
|
||||
viewModel: viewModel, selectedSource: $selectedSource,
|
||||
showSourceMenu: $showSourceMenu,
|
||||
selectedTags: $selectedTags,
|
||||
selectedCategories: $selectedCategories,
|
||||
selectedVendors: $selectedVendors,
|
||||
quickFilterTags: viewModel.quickFilterTags,
|
||||
showFilterMenu: $showFilterMenu,
|
||||
onSourceChange: handleSourceChange
|
||||
)
|
||||
}
|
||||
|
||||
// Extract single model row view as independent method
|
||||
@ViewBuilder
|
||||
private func modelRowView(model: ModelInfo, index: Int) -> some View {
|
||||
ModelRowView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
|
||||
isDownloading: viewModel.currentlyDownloading == model.id,
|
||||
isOtherDownloading: isOtherDownloadingCheck(model: model)
|
||||
) {
|
||||
Task {
|
||||
await viewModel.downloadModel(model)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
|
||||
// Extract complex boolean logic as independent method
|
||||
private func isOtherDownloadingCheck(model: ModelInfo) -> Bool {
|
||||
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
|
||||
}
|
||||
|
||||
// Extract source change handling logic as independent method
|
||||
private func handleSourceChange(_ source: ModelSource) {
|
||||
ModelSourceManager.shared.updateSelectedSource(source)
|
||||
selectedSource = source
|
||||
Task {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
// Filter models based on selected tags, categories and vendors
|
||||
private var filteredModels: [ModelInfo] {
|
||||
let baseFiltered = viewModel.filteredModels
|
||||
|
||||
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
|
||||
return baseFiltered
|
||||
}
|
||||
|
||||
return baseFiltered.filter { model in
|
||||
let tagMatch = checkTagMatch(model: model)
|
||||
let categoryMatch = checkCategoryMatch(model: model)
|
||||
let vendorMatch = checkVendorMatch(model: model)
|
||||
|
||||
return tagMatch && categoryMatch && vendorMatch
|
||||
}
|
||||
}
|
||||
|
||||
// Extract tag matching logic as independent method
|
||||
private func checkTagMatch(model: ModelInfo) -> Bool {
|
||||
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
|
||||
model.localizedTags.contains { tag in
|
||||
tag.localizedCaseInsensitiveContains(selectedTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func checkFirstLaunch() {
|
||||
let hasLaunchedBefore = UserDefaults.standard.bool(forKey: "hasLaunchedBefore")
|
||||
if !hasLaunchedBefore {
|
||||
// Show the user guide alert
|
||||
showUserGuide = true
|
||||
// Set the flag to true so it doesn't show again
|
||||
UserDefaults.standard.set(true, forKey: "hasLaunchedBefore")
|
||||
// Extract category matching logic as independent method
|
||||
private func checkCategoryMatch(model: ModelInfo) -> Bool {
|
||||
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
|
||||
model.categories?.contains { category in
|
||||
category.localizedCaseInsensitiveContains(selectedCategory)
|
||||
} ?? false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract vendor matching logic as independent method
|
||||
private func checkVendorMatch(model: ModelInfo) -> Bool {
|
||||
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
|
||||
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ import SwiftUI
|
|||
|
||||
// MARK: - 操作按钮视图
|
||||
struct ActionButtonsView: View {
|
||||
let model: TBModelInfo
|
||||
@ObservedObject var viewModel: TBModelListViewModel
|
||||
let model: ModelInfo
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
let downloadProgress: Double
|
||||
let isDownloading: Bool
|
||||
let isOtherDownloading: Bool
|
||||
|
|
|
@ -9,7 +9,7 @@ import SwiftUI
|
|||
|
||||
// MARK: - 下载中按钮视图
|
||||
struct DownloadingButtonView: View {
|
||||
@ObservedObject var viewModel: TBModelListViewModel
|
||||
@ObservedObject var viewModel: ModelListViewModel
|
||||
let downloadProgress: Double
|
||||
|
||||
var body: some View {
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// ModelRowView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
@ -17,127 +17,78 @@ struct ModelRowView: View {
|
|||
let isOtherDownloading: Bool
|
||||
let onDownload: () -> Void
|
||||
|
||||
|
||||
@State private var showDeleteAlert = false
|
||||
|
||||
// 预计算本地化标签,避免重复计算
|
||||
private var localizedTags: [String] {
|
||||
model.localizedTags
|
||||
}
|
||||
|
||||
// 预计算格式化大小,避免重复计算
|
||||
private var formattedSize: String {
|
||||
model.formattedSize
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .top) {
|
||||
ModelIconView(modelId: model.modelId)
|
||||
HStack(alignment: .top, spacing: 0) {
|
||||
// 模型图标
|
||||
ModelIconView(modelId: model.id)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
VStack(alignment: .leading, spacing: 5) {
|
||||
Text(model.name)
|
||||
// 主要信息区域
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
// 模型名称
|
||||
Text(model.modelName)
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
.lineLimit(1)
|
||||
|
||||
if let lastUsedAt = model.lastUsedAt {
|
||||
Text("Last used: \(lastUsedAt.formatAgo())")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
|
||||
if !model.tags.isEmpty {
|
||||
ScrollView(.horizontal, showsIndicators: false) {
|
||||
HStack {
|
||||
ForEach(model.tags, id: \.self) { tag in
|
||||
Text(tag)
|
||||
.fontWeight(.regular)
|
||||
.font(.caption)
|
||||
.foregroundColor(.gray)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 3)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.gray.opacity(0.5), lineWidth: 0.5)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
.frame(height: 25)
|
||||
// 标签列表
|
||||
if !localizedTags.isEmpty {
|
||||
TagsView(tags: localizedTags)
|
||||
}
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
|
||||
Spacer()
|
||||
|
||||
VStack(alignment: .center, spacing: 4) {
|
||||
if model.isDownloaded {
|
||||
Button(action: {
|
||||
showDeleteAlert = true
|
||||
}) {
|
||||
Image(systemName: "trash")
|
||||
.fontWeight(.regular)
|
||||
.foregroundColor(.black.opacity(0.8))
|
||||
.frame(width: 20, height: 20)
|
||||
|
||||
Text("已下载")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.gray)
|
||||
.padding(.top, 4)
|
||||
}
|
||||
} else {
|
||||
if isDownloading {
|
||||
Button(action: {
|
||||
Task {
|
||||
await viewModel.cancelDownload()
|
||||
}
|
||||
}) {
|
||||
ProgressView(value: downloadProgress)
|
||||
.progressViewStyle(CircularProgressViewStyle())
|
||||
.frame(width: 28, height: 28)
|
||||
Text(String(format: "%.2f%%", downloadProgress * 100))
|
||||
.font(.caption2)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
} else {
|
||||
Button(action: onDownload) {
|
||||
Image(systemName: "arrow.down.circle.fill")
|
||||
.font(.title2)
|
||||
}
|
||||
.foregroundColor(isOtherDownloading ? .gray : .primaryPurple)
|
||||
.disabled(isOtherDownloading)
|
||||
|
||||
HStack(alignment: .bottom, spacing: 2) {
|
||||
Image(systemName: "folder")
|
||||
.font(.caption2)
|
||||
Text(model.formattedSize)
|
||||
.font(.caption2)
|
||||
.lineLimit(1)
|
||||
.minimumScaleFactor(0.8)
|
||||
.offset(y: 1)
|
||||
.onAppear {
|
||||
if !model.isDownloaded && model.cachedSize == nil {
|
||||
Task {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
if let index = viewModel.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
viewModel.models[index].cachedSize = size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.frame(width: 60)
|
||||
ActionButtonsView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: downloadProgress,
|
||||
isDownloading: isDownloading,
|
||||
isOtherDownloading: isOtherDownloading,
|
||||
formattedSize: formattedSize,
|
||||
onDownload: onDownload,
|
||||
showDeleteAlert: $showDeleteAlert
|
||||
)
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
.alert(isPresented: $showDeleteAlert) {
|
||||
Alert(
|
||||
title: Text("确认删除"),
|
||||
message: Text("是否确认删除该模型?"),
|
||||
primaryButton: .destructive(Text("删除")) {
|
||||
Task {
|
||||
await viewModel.deleteModel(model)
|
||||
}
|
||||
},
|
||||
secondaryButton: .cancel(Text("取消"))
|
||||
)
|
||||
.contentShape(Rectangle()) // 确保整个区域都可以点击
|
||||
.onTapGesture {
|
||||
handleRowTap()
|
||||
}
|
||||
.alert("确认删除", isPresented: $showDeleteAlert) {
|
||||
Button("删除", role: .destructive) {
|
||||
Task {
|
||||
await viewModel.deleteModel(model)
|
||||
}
|
||||
}
|
||||
Button("取消", role: .cancel) { }
|
||||
} message: {
|
||||
Text("是否确认删除该模型?")
|
||||
}
|
||||
}
|
||||
|
||||
private func handleRowTap() {
|
||||
if model.isDownloaded {
|
||||
return
|
||||
} else if isDownloading {
|
||||
Task {
|
||||
await viewModel.cancelDownload()
|
||||
}
|
||||
} else if !isOtherDownloading {
|
||||
onDownload()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ struct SwipeActionsView: View {
|
|||
@ObservedObject var viewModel: ModelListViewModel
|
||||
|
||||
var body: some View {
|
||||
if viewModel.pinnedModelIds.contains(model.modelId) {
|
||||
if viewModel.pinnedModelIds.contains(model.id) {
|
||||
Button {
|
||||
viewModel.unpinModel(model)
|
||||
} label: {
|
||||
|
|
|
@ -1,147 +0,0 @@
|
|||
//
|
||||
// TBModelListView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct TBModelListView: View {
|
||||
@ObservedObject var viewModel: TBModelListViewModel
|
||||
@State private var searchText = ""
|
||||
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
||||
@State private var showSourceMenu = false
|
||||
@State private var selectedTags: Set<String> = []
|
||||
@State private var selectedCategories: Set<String> = []
|
||||
@State private var selectedVendors: Set<String> = []
|
||||
@State private var showFilterMenu = false
|
||||
|
||||
var body: some View {
|
||||
ScrollView {
|
||||
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||
Section {
|
||||
modelListSection
|
||||
} header: {
|
||||
toolbarSection
|
||||
}
|
||||
}
|
||||
}
|
||||
.searchable(text: $searchText, prompt: "Search models...")
|
||||
.onChange(of: searchText) { _, newValue in
|
||||
viewModel.searchText = newValue
|
||||
}
|
||||
.refreshable {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
.alert("Error", isPresented: $viewModel.showError) {
|
||||
Button("OK") { }
|
||||
} message: {
|
||||
Text(viewModel.errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract model list section as independent view
|
||||
@ViewBuilder
|
||||
private var modelListSection: some View {
|
||||
LazyVStack(spacing: 8) {
|
||||
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
|
||||
modelRowView(model: model, index: index)
|
||||
|
||||
if index < filteredModels.count - 1 {
|
||||
Divider()
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
}
|
||||
|
||||
// Extract toolbar section as independent view
|
||||
@ViewBuilder
|
||||
private var toolbarSection: some View {
|
||||
ToolbarView(
|
||||
viewModel: viewModel, selectedSource: $selectedSource,
|
||||
showSourceMenu: $showSourceMenu,
|
||||
selectedTags: $selectedTags,
|
||||
selectedCategories: $selectedCategories,
|
||||
selectedVendors: $selectedVendors,
|
||||
quickFilterTags: viewModel.quickFilterTags,
|
||||
showFilterMenu: $showFilterMenu,
|
||||
onSourceChange: handleSourceChange
|
||||
)
|
||||
}
|
||||
|
||||
// Extract single model row view as independent method
|
||||
@ViewBuilder
|
||||
private func modelRowView(model: TBModelInfo, index: Int) -> some View {
|
||||
TBModelRowView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
|
||||
isDownloading: viewModel.currentlyDownloading == model.id,
|
||||
isOtherDownloading: isOtherDownloadingCheck(model: model)
|
||||
) {
|
||||
Task {
|
||||
await viewModel.downloadModel(model)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
}
|
||||
|
||||
// Extract complex boolean logic as independent method
|
||||
private func isOtherDownloadingCheck(model: TBModelInfo) -> Bool {
|
||||
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
|
||||
}
|
||||
|
||||
// Extract source change handling logic as independent method
|
||||
private func handleSourceChange(_ source: ModelSource) {
|
||||
ModelSourceManager.shared.updateSelectedSource(source)
|
||||
selectedSource = source
|
||||
Task {
|
||||
await viewModel.fetchModels()
|
||||
}
|
||||
}
|
||||
|
||||
// Filter models based on selected tags, categories and vendors
|
||||
private var filteredModels: [TBModelInfo] {
|
||||
let baseFiltered = viewModel.filteredModels
|
||||
|
||||
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
|
||||
return baseFiltered
|
||||
}
|
||||
|
||||
return baseFiltered.filter { model in
|
||||
let tagMatch = checkTagMatch(model: model)
|
||||
let categoryMatch = checkCategoryMatch(model: model)
|
||||
let vendorMatch = checkVendorMatch(model: model)
|
||||
|
||||
return tagMatch && categoryMatch && vendorMatch
|
||||
}
|
||||
}
|
||||
|
||||
// Extract tag matching logic as independent method
|
||||
private func checkTagMatch(model: TBModelInfo) -> Bool {
|
||||
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
|
||||
model.localizedTags.contains { tag in
|
||||
tag.localizedCaseInsensitiveContains(selectedTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract category matching logic as independent method
|
||||
private func checkCategoryMatch(model: TBModelInfo) -> Bool {
|
||||
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
|
||||
model.categories?.contains { category in
|
||||
category.localizedCaseInsensitiveContains(selectedCategory)
|
||||
} ?? false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract vendor matching logic as independent method
|
||||
private func checkVendorMatch(model: TBModelInfo) -> Bool {
|
||||
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
|
||||
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,94 +0,0 @@
|
|||
//
|
||||
// TBModelRowView.swift
|
||||
// MNNLLMiOS
|
||||
//
|
||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
struct TBModelRowView: View {
|
||||
|
||||
let model: TBModelInfo
|
||||
@ObservedObject var viewModel: TBModelListViewModel
|
||||
|
||||
let downloadProgress: Double
|
||||
let isDownloading: Bool
|
||||
let isOtherDownloading: Bool
|
||||
let onDownload: () -> Void
|
||||
|
||||
@State private var showDeleteAlert = false
|
||||
|
||||
// 预计算本地化标签,避免重复计算
|
||||
private var localizedTags: [String] {
|
||||
model.localizedTags
|
||||
}
|
||||
|
||||
// 预计算格式化大小,避免重复计算
|
||||
private var formattedSize: String {
|
||||
model.formattedSize
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .top, spacing: 0) {
|
||||
// 模型图标
|
||||
ModelIconView(modelId: model.id)
|
||||
.frame(width: 40, height: 40)
|
||||
|
||||
// 主要信息区域
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
// 模型名称
|
||||
Text(model.modelName)
|
||||
.font(.headline)
|
||||
.fontWeight(.semibold)
|
||||
.lineLimit(1)
|
||||
|
||||
// 标签列表
|
||||
if !localizedTags.isEmpty {
|
||||
TagsView(tags: localizedTags)
|
||||
}
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
|
||||
Spacer()
|
||||
|
||||
ActionButtonsView(
|
||||
model: model,
|
||||
viewModel: viewModel,
|
||||
downloadProgress: downloadProgress,
|
||||
isDownloading: isDownloading,
|
||||
isOtherDownloading: isOtherDownloading,
|
||||
formattedSize: formattedSize,
|
||||
onDownload: onDownload,
|
||||
showDeleteAlert: $showDeleteAlert
|
||||
)
|
||||
}
|
||||
.padding(.vertical, 8)
|
||||
.contentShape(Rectangle()) // 确保整个区域都可以点击
|
||||
.onTapGesture {
|
||||
handleRowTap()
|
||||
}
|
||||
.alert("确认删除", isPresented: $showDeleteAlert) {
|
||||
Button("删除", role: .destructive) {
|
||||
Task {
|
||||
await viewModel.deleteModel(model)
|
||||
}
|
||||
}
|
||||
Button("取消", role: .cancel) { }
|
||||
} message: {
|
||||
Text("是否确认删除该模型?")
|
||||
}
|
||||
}
|
||||
|
||||
private func handleRowTap() {
|
||||
if model.isDownloaded {
|
||||
return
|
||||
} else if isDownloading {
|
||||
Task {
|
||||
await viewModel.cancelDownload()
|
||||
}
|
||||
} else if !isOtherDownloading {
|
||||
onDownload()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,6 +19,9 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Audio Message" : {
|
||||
|
||||
},
|
||||
"Benchmark" : {
|
||||
|
||||
|
@ -202,9 +205,6 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Last used: %@" : {
|
||||
|
||||
},
|
||||
"Model Configuration" : {
|
||||
"localizations" : {
|
||||
|
@ -365,11 +365,9 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"TB模型" : {
|
||||
|
||||
},
|
||||
"This is a local large model application that requires certain performance from your device.\nIt is recommended to choose different model sizes based on your device's memory. \n\nThe model recommendations for iPhone are as follows:\n- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).\n- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).\n- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).\n\nChoosing a model that is too large may cause insufficient memory and crashes." : {
|
||||
"extractionState" : "stale",
|
||||
"localizations" : {
|
||||
"en" : {
|
||||
"stringUnit" : {
|
||||
|
@ -426,6 +424,7 @@
|
|||
}
|
||||
},
|
||||
"User Guide" : {
|
||||
"extractionState" : "stale",
|
||||
"localizations" : {
|
||||
"zh-Hans" : {
|
||||
"stringUnit" : {
|
||||
|
@ -525,9 +524,6 @@
|
|||
},
|
||||
"语言" : {
|
||||
|
||||
},
|
||||
"错误" : {
|
||||
|
||||
}
|
||||
},
|
||||
"version" : "1.0"
|
||||
|
|
Loading…
Reference in New Issue