[feat] change data source

This commit is contained in:
游薪渝(揽清) 2025-07-10 11:16:25 +08:00
parent 3d2091bc24
commit 85855f7dbc
25 changed files with 648 additions and 1386 deletions

View File

@ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") { PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
if let isDeepSeek = self?.modelInfo.name.lowercased().contains("deepseek"), isDeepSeek == true, if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
let text = self?.processor.process(progress: message.text) { let text = self?.processor.process(progress: message.text) {
updateLastMsg?.text = text updateLastMsg?.text = text
} else { } else {

View File

@ -25,13 +25,13 @@ final class LLMChatData {
self.assistant = LLMChatUser( self.assistant = LLMChatUser(
uid: "2", uid: "2",
name: modelInfo.name, name: modelInfo.modelName,
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png") avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
) )
self.system = LLMChatUser( self.system = LLMChatUser(
uid: "0", uid: "0",
name: modelInfo.name, name: modelInfo.modelName,
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png") avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
) )
} }

View File

@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
let modelConfigManager: ModelConfigManager let modelConfigManager: ModelConfigManager
var isDiffusionModel: Bool { var isDiffusionModel: Bool {
return modelInfo.name.lowercased().contains("diffusion") return modelInfo.modelName.lowercased().contains("diffusion")
} }
init(modelInfo: ModelInfo, history: ChatHistory? = nil) { init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
), userType: .system) ), userType: .system)
} }
if modelInfo.name.lowercased().contains("diffusion") { if modelInfo.modelName.lowercased().contains("diffusion") {
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
Task { @MainActor in Task { @MainActor in
print("Diffusion Model \(success)") print("Diffusion Model \(success)")
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
func sendToLLM(draft: DraftMessage) { func sendToLLM(draft: DraftMessage) {
self.send(draft: draft, userType: .user) self.send(draft: draft, userType: .user)
if isModelLoaded { if isModelLoaded {
if modelInfo.name.lowercased().contains("diffusion") { if modelInfo.modelName.lowercased().contains("diffusion") {
self.getDiffusionResponse(draft: draft) self.getDiffusionResponse(draft: draft)
} else { } else {
self.getLLMRespsonse(draft: draft) self.getLLMRespsonse(draft: draft)
@ -284,7 +284,7 @@ final class LLMChatViewModel: ObservableObject {
} }
private func convertDeepSeekMutliChat(content: String) -> String { private func convertDeepSeekMutliChat(content: String) -> String {
if self.modelInfo.name.lowercased().contains("deepseek") { if self.modelInfo.modelName.lowercased().contains("deepseek") {
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|> /* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|> <|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
*/ */
@ -337,7 +337,7 @@ final class LLMChatViewModel: ObservableObject {
ChatHistoryManager.shared.saveChat( ChatHistoryManager.shared.saveChat(
historyId: historyId, historyId: historyId,
modelId: modelInfo.modelId, modelId: modelInfo.modelId,
modelName: modelInfo.name, modelName: modelInfo.modelName,
messages: messages messages: messages
) )
@ -387,4 +387,4 @@ final class LLMChatViewModel: ObservableObject {
print("Error accessing tmp directory: \(error.localizedDescription)") print("Error accessing tmp directory: \(error.localizedDescription)")
} }
} }
} }

View File

@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
let modelConfigManager: ModelConfigManager let modelConfigManager: ModelConfigManager
var isDiffusionModel: Bool { var isDiffusionModel: Bool {
return modelInfo.name.lowercased().contains("diffusion") return modelInfo.modelName.lowercased().contains("diffusion")
} }
init(modelInfo: ModelInfo, history: ChatHistory? = nil) { init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
), userType: .system) ), userType: .system)
} }
if modelInfo.name.lowercased().contains("diffusion") { if modelInfo.modelName.lowercased().contains("diffusion") {
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
Task { @MainActor in Task { @MainActor in
print("Diffusion Model \(success)") print("Diffusion Model \(success)")
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
func sendToLLM(draft: DraftMessage) { func sendToLLM(draft: DraftMessage) {
self.send(draft: draft, userType: .user) self.send(draft: draft, userType: .user)
if isModelLoaded { if isModelLoaded {
if modelInfo.name.lowercased().contains("diffusion") { if modelInfo.modelName.lowercased().contains("diffusion") {
self.getDiffusionResponse(draft: draft) self.getDiffusionResponse(draft: draft)
} else { } else {
self.getLLMRespsonse(draft: draft) self.getLLMRespsonse(draft: draft)
@ -298,7 +298,7 @@ final class LLMChatViewModel: ObservableObject {
} }
private func convertDeepSeekMutliChat(content: String) -> String { private func convertDeepSeekMutliChat(content: String) -> String {
if self.modelInfo.name.lowercased().contains("deepseek") { if self.modelInfo.modelName.lowercased().contains("deepseek") {
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|> /* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|> <|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
*/ */
@ -350,8 +350,8 @@ final class LLMChatViewModel: ObservableObject {
func onStop() { func onStop() {
ChatHistoryManager.shared.saveChat( ChatHistoryManager.shared.saveChat(
historyId: historyId, historyId: historyId,
modelId: modelInfo.modelId, modelId: modelInfo.id,
modelName: modelInfo.name, modelName: modelInfo.modelName,
messages: messages messages: messages
) )

View File

@ -24,7 +24,7 @@ struct LLMChatView: View {
@State private var showSettings = false @State private var showSettings = false
init(modelInfo: ModelInfo, history: ChatHistory? = nil) { init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
self.title = modelInfo.name self.title = modelInfo.modelName
self.modelPath = modelInfo.localPath self.modelPath = modelInfo.localPath
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history) let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
_viewModel = StateObject(wrappedValue: viewModel) _viewModel = StateObject(wrappedValue: viewModel)

View File

@ -12,14 +12,14 @@ struct LocalModelListView: View {
var body: some View { var body: some View {
List { List {
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.modelId) { model in ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
Button(action: { Button(action: {
viewModel.selectModel(model) viewModel.selectModel(model)
}) { }) {
LocalModelRowView(model: model) LocalModelRowView(model: model)
} }
.listRowSeparator(.hidden) .listRowSeparator(.hidden)
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear) .listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
.swipeActions(edge: .trailing, allowsFullSwipe: false) { .swipeActions(edge: .trailing, allowsFullSwipe: false) {
SwipeActionsView(model: model, viewModel: viewModel) SwipeActionsView(model: model, viewModel: viewModel)
} }
@ -35,4 +35,4 @@ struct LocalModelListView: View {
Text(viewModel.errorMessage) Text(viewModel.errorMessage)
} }
} }
} }

View File

@ -11,36 +11,28 @@ struct LocalModelRowView: View {
let model: ModelInfo let model: ModelInfo
private var localizedTags: [String] {
model.localizedTags
}
private var formattedSize: String {
model.formattedSize
}
var body: some View { var body: some View {
HStack(alignment: .center) { HStack(alignment: .center) {
ModelIconView(modelId: model.modelId) ModelIconView(modelId: model.id)
.frame(width: 50, height: 50) .frame(width: 50, height: 50)
VStack(alignment: .leading, spacing: 8) { VStack(alignment: .leading, spacing: 8) {
Text(model.name) Text(model.modelName)
.font(.headline) .font(.headline)
.fontWeight(.semibold) .fontWeight(.semibold)
.lineLimit(1) .lineLimit(1)
if !model.tags.isEmpty { if !localizedTags.isEmpty {
ScrollView(.horizontal, showsIndicators: false) { TagsView(tags: localizedTags)
HStack {
ForEach(model.tags, id: \.self) { tag in
Text(tag)
.fontWeight(.regular)
.font(.caption)
.foregroundColor(Color(red: 151/255, green: 151/255, blue: 151/255))
.padding(.horizontal, 10)
.padding(.vertical, 4)
.background(
RoundedRectangle(cornerRadius: 10)
.stroke(Color(red: 151/255, green: 151/255, blue: 151/255), lineWidth: 0.5)
.padding(1)
)
}
}
}
} }
HStack { HStack {
@ -51,7 +43,7 @@ struct LocalModelRowView: View {
.foregroundColor(.gray) .foregroundColor(.gray)
.frame(width: 20, height: 20) .frame(width: 20, height: 20)
Text(model.formattedSize) Text(formattedSize)
.font(.caption) .font(.caption)
.fontWeight(.medium) .fontWeight(.medium)
.foregroundColor(.gray) .foregroundColor(.gray)

View File

@ -16,39 +16,32 @@ struct MainTabView: View {
@State private var showWebView = false @State private var showWebView = false
@State private var webViewURL: URL? @State private var webViewURL: URL?
@State private var navigateToSettings = false @State private var navigateToSettings = false
@StateObject private var modelListViewModel = TBModelListViewModel() @StateObject private var modelListViewModel = ModelListViewModel()
@StateObject private var localModelListViewModel = ModelListViewModel()
@State private var selectedTab: Int = 0 @State private var selectedTab: Int = 0
@State private var titles = ["本地模型", "模型市场", "TB模型", "Benchmark"] @State private var titles = ["本地模型", "模型市场", "Benchmark"]
var body: some View { var body: some View {
ZStack { ZStack {
NavigationView { NavigationView {
TabView(selection: $selectedTab) { TabView(selection: $selectedTab) {
LocalModelListView(viewModel: localModelListViewModel) LocalModelListView(viewModel: modelListViewModel)
.tabItem { .tabItem {
Image(systemName: "house.fill") Image(systemName: "house.fill")
Text("本地模型") Text("本地模型")
} }
.tag(0) .tag(0)
ModelListView(viewModel: localModelListViewModel) ModelListView(viewModel: modelListViewModel)
.tabItem { .tabItem {
Image(systemName: "cart.fill") Image(systemName: "doc.text.fill")
Text("模型市场") Text("模型市场")
} }
.tag(1) .tag(1)
TBModelListView(viewModel: modelListViewModel)
.tabItem {
Image(systemName: "doc.text.fill")
Text("TB模型")
}
.tag(2)
BenchmarkView() BenchmarkView()
.tabItem { .tabItem {
Image(systemName: "clock.fill") Image(systemName: "clock.fill")
Text("Benchmark") Text("Benchmark")
} }
.tag(3) .tag(2)
} }
.background( .background(
ZStack { ZStack {
@ -130,19 +123,13 @@ struct MainTabView: View {
@ViewBuilder @ViewBuilder
private var chatDestination: some View { private var chatDestination: some View {
if let model = localModelListViewModel.selectedModel { if let model = modelListViewModel.selectedModel {
LLMChatView(modelInfo: model) LLMChatView(modelInfo: model)
.navigationBarHidden(false) .navigationBarHidden(false)
.navigationBarTitleDisplayMode(.inline) .navigationBarTitleDisplayMode(.inline)
.toolbar(.hidden, for: .tabBar) // Hide tab bar in chat .toolbar(.hidden, for: .tabBar) // Hide tab bar in chat
} else if let history = selectedHistory { } else if let history = selectedHistory {
let modelInfo = ModelInfo( let modelInfo = ModelInfo(modelId: history.modelId, isDownloaded: true)
modelId: history.modelId,
createdAt: "",
downloads: 0,
tags: [],
isDownloaded: true
)
LLMChatView(modelInfo: modelInfo, history: history) LLMChatView(modelInfo: modelInfo, history: history)
.navigationBarHidden(false) .navigationBarHidden(false)
.navigationBarTitleDisplayMode(.inline) .navigationBarTitleDisplayMode(.inline)
@ -155,17 +142,17 @@ struct MainTabView: View {
private var chatIsActiveBinding: Binding<Bool> { private var chatIsActiveBinding: Binding<Bool> {
Binding<Bool>( Binding<Bool>(
get: { get: {
return localModelListViewModel.selectedModel != nil || selectedHistory != nil return modelListViewModel.selectedModel != nil || selectedHistory != nil
}, },
set: { isActive in set: { isActive in
if !isActive { if !isActive {
// Record usage when returning from chat // Record usage when returning from chat
if let model = localModelListViewModel.selectedModel { if let model = modelListViewModel.selectedModel {
localModelListViewModel.recordModelUsage(modelName: model.name) modelListViewModel.recordModelUsage(modelName: model.modelName)
} }
// Clear selections // Clear selections
localModelListViewModel.selectedModel = nil modelListViewModel.selectedModel = nil
selectedHistory = nil selectedHistory = nil
} }
} }
@ -195,4 +182,4 @@ struct MainTabView: View {
UITabBar.appearance().standardAppearance = appearance UITabBar.appearance().standardAppearance = appearance
UITabBar.appearance().scrollEdgeAppearance = appearance UITabBar.appearance().scrollEdgeAppearance = appearance
} }
} }

View File

@ -1,83 +1,108 @@
// //
// ModelClient.swift // TBModelInfo.swift
// MNNLLMiOS // MNNLLMiOS
// //
// Created by () on 2025/1/3. // Created by () on 2025/7/4.
// //
import Hub import Hub
import Foundation import Foundation
struct ModelInfo: Codable { struct ModelInfo: Codable {
let modelId: String // MARK: - Properties
let createdAt: String let modelName: String
let downloads: Int
let tags: [String] let tags: [String]
let categories: [String]?
let size_gb: Double?
let vendor: String?
let sources: [String: String]?
let tagTranslations: [String: [String]]?
var name: String { // Runtime properties
modelId.removingTaobaoPrefix()
}
var isDownloaded: Bool = false var isDownloaded: Bool = false
var lastUsedAt: Date? var lastUsedAt: Date?
var cachedSize: Int64? = nil var cachedSize: Int64? = nil
var localPath: String { // MARK: - Initialization
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
init(modelName: String = "",
tags: [String] = [],
categories: [String]? = nil,
size_gb: Double? = nil,
vendor: String? = nil,
sources: [String: String]? = nil,
tagTranslations: [String: [String]]? = nil,
isDownloaded: Bool = false,
lastUsedAt: Date? = nil,
cachedSize: Int64? = nil) {
self.modelName = modelName
self.tags = tags
self.categories = categories
self.size_gb = size_gb
self.vendor = vendor
self.sources = sources
self.tagTranslations = tagTranslations
self.isDownloaded = isDownloaded
self.lastUsedAt = lastUsedAt
self.cachedSize = cachedSize
} }
init(modelId: String, isDownloaded: Bool = true) {
let modelName = modelId.components(separatedBy: "/").last ?? modelId
self.init(
modelName: modelName,
tags: [],
sources: ["huggingface": modelId],
isDownloaded: isDownloaded
)
}
// MARK: - Model Identity & Localization
var id: String {
guard let sources = sources else {
return "taobao-mnn/\(modelName)"
}
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
}
var localizedTags: [String] {
let currentLanguage = LanguageManager.shared.currentLanguage
let isChineseLanguage = currentLanguage == "简体中文"
if isChineseLanguage, let translations = tagTranslations {
let languageCode = "zh-Hans"
return translations[languageCode] ?? tags
} else {
return tags
}
}
// MARK: - File System & Path Management
var localPath: String {
let modelScopeId = "taobao-mnn/\(modelName)"
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
}
// MARK: - Size Calculation & Formatting
var formattedSize: String { var formattedSize: String {
if isDownloaded { if let cached = cachedSize {
return formatLocalSize()
} else if let cached = cachedSize {
return formatBytes(cached) return formatBytes(cached)
} else if isDownloaded {
return formatLocalSize()
} else if let sizeGb = size_gb {
return String(format: "%.1f GB", sizeGb)
} else { } else {
return "计算中..." return "Calculating..."
} }
} }
func fetchRemoteSize() async -> Int64? {
let modelScopeId = modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
do {
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
return totalSize
} catch {
print("Error fetching remote size for \(modelId): \(error)")
return nil
}
}
private func formatLocalSize() -> String {
let path = localPath
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
do {
let totalSize = try calculateDirectorySize(at: path)
return formatBytes(totalSize)
} catch {
return "未知"
}
}
private func calculateDirectorySize(at path: String) throws -> Int64 {
let fileManager = FileManager.default
var totalSize: Int64 = 0
let enumerator = fileManager.enumerator(atPath: path)
while let fileName = enumerator?.nextObject() as? String {
let filePath = (path as NSString).appendingPathComponent(fileName)
let attributes = try fileManager.attributesOfItem(atPath: filePath)
if let fileSize = attributes[.size] as? Int64 {
totalSize += fileSize
}
}
return totalSize
}
private func formatBytes(_ bytes: Int64) -> String { private func formatBytes(_ bytes: Int64) -> String {
let formatter = ByteCountFormatter() let formatter = ByteCountFormatter()
formatter.allowedUnits = [.useGB] formatter.allowedUnits = [.useGB]
@ -85,7 +110,102 @@ struct ModelInfo: Codable {
return formatter.string(fromByteCount: bytes) return formatter.string(fromByteCount: bytes)
} }
// MARK: - // MARK: - Local Size Calculation
private func formatLocalSize() -> String {
let path = localPath
guard FileManager.default.fileExists(atPath: path) else { return "Unknown" }
do {
let totalSize = try calculateDirectorySize(at: path)
return formatBytes(totalSize)
} catch {
return "Unknown"
}
}
private func calculateDirectorySize(at path: String) throws -> Int64 {
let fileManager = FileManager.default
var totalSize: Int64 = 0
print("Calculating directory size for path: \(path)")
let directoryURL = URL(fileURLWithPath: path)
guard fileManager.fileExists(atPath: path) else {
print("Path does not exist: \(path)")
return 0
}
let resourceKeys: [URLResourceKey] = [.isRegularFileKey, .totalFileAllocatedSizeKey, .fileSizeKey, .nameKey]
let enumerator = fileManager.enumerator(
at: directoryURL,
includingPropertiesForKeys: resourceKeys,
options: [.skipsHiddenFiles, .skipsPackageDescendants],
errorHandler: { (url, error) -> Bool in
print("Error accessing \(url): \(error)")
return true
}
)
guard let fileEnumerator = enumerator else {
throw NSError(domain: "FileEnumerationError", code: -1,
userInfo: [NSLocalizedDescriptionKey: "Failed to create file enumerator"])
}
var fileCount = 0
for case let fileURL as URL in fileEnumerator {
do {
let resourceValues = try fileURL.resourceValues(forKeys: Set(resourceKeys))
guard let isRegularFile = resourceValues.isRegularFile, isRegularFile else { continue }
let fileName = resourceValues.name ?? "Unknown"
fileCount += 1
// Use actual disk allocated size, fallback to logical size if not available
if let actualSize = resourceValues.totalFileAllocatedSize {
totalSize += Int64(actualSize)
if fileCount <= 10 {
let actualSizeGB = Double(actualSize) / (1024 * 1024 * 1024)
let logicalSizeGB = Double(resourceValues.fileSize ?? 0) / (1024 * 1024 * 1024)
print("File \(fileCount): \(fileName) - Logical: \(String(format: "%.3f", logicalSizeGB)) GB, Actual: \(String(format: "%.3f", actualSizeGB)) GB")
}
} else if let logicalSize = resourceValues.fileSize {
totalSize += Int64(logicalSize)
if fileCount <= 10 {
let logicalSizeGB = Double(logicalSize) / (1024 * 1024 * 1024)
print("File \(fileCount): \(fileName) - Size: \(String(format: "%.3f", logicalSizeGB)) GB (fallback)")
}
}
} catch {
print("Error getting resource values for \(fileURL): \(error)")
continue
}
}
let totalSizeGB = Double(totalSize) / (1024 * 1024 * 1024)
print("Total files: \(fileCount), Total actual disk usage: \(String(format: "%.2f", totalSizeGB)) GB")
return totalSize
}
// MARK: - Remote Size Calculation
func fetchRemoteSize() async -> Int64? {
let modelScopeId = "taobao-mnn/\(modelName)"
do {
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
return totalSize
} catch {
print("Error fetching remote size for \(id): \(error)")
return nil
}
}
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] { private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
let url = try buildURL( let url = try buildURL(
@ -119,6 +239,8 @@ struct ModelInfo: Codable {
return totalSize return totalSize
} }
// MARK: - Network Utilities
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL { private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
var components = URLComponents() var components = URLComponents()
components.scheme = "https" components.scheme = "https"
@ -139,21 +261,9 @@ struct ModelInfo: Codable {
} }
} }
// MARK: - Codable
private enum CodingKeys: String, CodingKey { private enum CodingKeys: String, CodingKey {
case modelId case modelName, tags, categories, size_gb, vendor, sources, tagTranslations, cachedSize
case tags
case downloads
case createdAt
case cachedSize
}
}
struct RepoInfo: Codable {
let modelId: String
let sha: String
let siblings: [Sibling]
struct Sibling: Codable {
let rfilename: String
} }
} }

View File

@ -2,7 +2,7 @@
// ModelListViewModel.swift // ModelListViewModel.swift
// MNNLLMiOS // MNNLLMiOS
// //
// Created by () on 2025/1/3. // Created by () on 2025/7/4.
// //
import Foundation import Foundation
@ -10,74 +10,78 @@ import SwiftUI
@MainActor @MainActor
class ModelListViewModel: ObservableObject { class ModelListViewModel: ObservableObject {
// MARK: - Published Properties
@Published var models: [ModelInfo] = [] @Published var models: [ModelInfo] = []
@Published private(set) var downloadProgress: [String: Double] = [:] @Published var searchText = ""
@Published private(set) var currentlyDownloading: String? @Published var quickFilterTags: [String] = []
@Published var selectedModel: ModelInfo?
@Published var showError = false @Published var showError = false
@Published var errorMessage = "" @Published var errorMessage = ""
@Published var searchText = ""
@Published var selectedModel: ModelInfo? // Download state
@Published private(set) var downloadProgress: [String: Double] = [:]
@Published private(set) var currentlyDownloading: String?
// MARK: - Private Properties
private let modelClient = ModelClient() private let modelClient = ModelClient()
private let pinnedModelKey = "com.mnnllm.pinnedModelIds" private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
// MARK: - Model Data Access
public var pinnedModelIds: [String] { public var pinnedModelIds: [String] {
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] } get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) } set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
} }
var allTags: [String] {
Array(Set(models.flatMap { $0.tags }))
}
var allCategories: [String] {
Array(Set(models.compactMap { $0.categories }.flatMap { $0 }))
}
var allVendors: [String] {
Array(Set(models.compactMap { $0.vendor }))
}
var filteredModels: [ModelInfo] { var filteredModels: [ModelInfo] {
let filtered = searchText.isEmpty ? models : models.filter { model in
let filteredModels = searchText.isEmpty ? models : models.filter { model in model.id.localizedCaseInsensitiveContains(searchText) ||
model.modelId.localizedCaseInsensitiveContains(searchText) || model.modelName.localizedCaseInsensitiveContains(searchText) ||
model.tags.contains { $0.localizedCaseInsensitiveContains(searchText) } model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
} }
let downloadedModels = filteredModels.filter { $0.isDownloaded } let downloaded = filtered.filter { $0.isDownloaded }
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded } let notDownloaded = filtered.filter { !$0.isDownloaded }
return downloadedModels + notDownloadedModels return downloaded + notDownloaded
} }
// MARK: - Initialization
init() { init() {
Task { Task {
await fetchModels() await fetchModels()
} }
} }
// MARK: - Model Data Management
func fetchModels() async { func fetchModels() async {
do { do {
var fetchedModels = try await modelClient.getModelList() let info = try await modelClient.getModelInfo()
let hasDiffusionModels = fetchedModels.contains { self.quickFilterTags = info.quickFilterTags ?? []
$0.name.lowercased().contains("diffusion") TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
}
if hasDiffusionModels { var fetchedModels = info.models
fetchedModels = fetchedModels.filter { model in
let name = model.name.lowercased()
let tags = model.tags.map { $0.lowercased() }
// only show gpu diffusion
if name.contains("diffusion") {
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
}
return true
}
}
for i in 0..<fetchedModels.count { filterDiffusionModels(fetchedModels: &fetchedModels)
let model = fetchedModels[i]
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.name)
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.name)
}
// Sort models
sortModels(fetchedModels: &fetchedModels) sortModels(fetchedModels: &fetchedModels)
self.models = fetchedModels
// // Asynchronously fetch size info for undownloaded models
Task { Task {
await fetchModelSizes(for: fetchedModels) await fetchModelSizes(for: fetchedModels)
} }
@ -91,12 +95,12 @@ class ModelListViewModel: ObservableObject {
private func fetchModelSizes(for models: [ModelInfo]) async { private func fetchModelSizes(for models: [ModelInfo]) async {
await withTaskGroup(of: Void.self) { group in await withTaskGroup(of: Void.self) { group in
for (_, model) in models.enumerated() { for (_, model) in models.enumerated() {
if !model.isDownloaded && model.cachedSize == nil { if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
group.addTask { group.addTask {
if let size = await model.fetchRemoteSize() { if let size = await model.fetchRemoteSize() {
await MainActor.run { await MainActor.run {
// // Find current model index in actual array
if let modelIndex = self.models.firstIndex(where: { $0.modelId == model.modelId }) { if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[modelIndex].cachedSize = size self.models[modelIndex].cachedSize = size
} }
} }
@ -107,11 +111,29 @@ class ModelListViewModel: ObservableObject {
} }
} }
func recordModelUsage(modelName: String) { private func filterDiffusionModels(fetchedModels: inout [ModelInfo]) {
ModelStorageManager.shared.updateLastUsed(for: modelName) let hasDiffusionModels = fetchedModels.contains {
if let index = models.firstIndex(where: { $0.name == modelName }) { $0.modelName.lowercased().contains("diffusion")
models[index].lastUsedAt = Date() }
sortModels(fetchedModels: &models)
if hasDiffusionModels {
fetchedModels = fetchedModels.filter { model in
let name = model.modelName.lowercased()
let tags = model.tags.map { $0.lowercased() }
// Only show GPU diffusion models
if name.contains("diffusion") {
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
}
return true
}
}
for i in 0..<fetchedModels.count {
let model = fetchedModels[i]
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
} }
} }
@ -119,24 +141,34 @@ class ModelListViewModel: ObservableObject {
let pinned = pinnedModelIds let pinned = pinnedModelIds
fetchedModels.sort { (model1, model2) -> Bool in fetchedModels.sort { (model1, model2) -> Bool in
let isPinned1 = pinned.contains(model1.modelId) let isPinned1 = pinned.contains(model1.id)
let isPinned2 = pinned.contains(model2.modelId) let isPinned2 = pinned.contains(model2.id)
let isDownloading1 = currentlyDownloading == model1.id
let isDownloading2 = currentlyDownloading == model2.id
// 1. Currently downloading models have highest priority
if isDownloading1 != isDownloading2 {
return isDownloading1
}
// 2. Pinned models have second priority
if isPinned1 != isPinned2 { if isPinned1 != isPinned2 {
return isPinned1 return isPinned1
} }
// 3. If both are pinned, sort by pin time
if isPinned1 && isPinned2 { if isPinned1 && isPinned2 {
let index1 = pinned.firstIndex(of: model1.modelId)! let index1 = pinned.firstIndex(of: model1.id)!
let index2 = pinned.firstIndex(of: model2.modelId)! let index2 = pinned.firstIndex(of: model2.id)!
return index1 > index2 // Pinned later comes first return index1 > index2 // Pinned later comes first
} }
// Non-pinned models // 4. Non-pinned models sorted by download status
if model1.isDownloaded != model2.isDownloaded { if model1.isDownloaded != model2.isDownloaded {
return model1.isDownloaded return model1.isDownloaded
} }
// 5. If both downloaded, sort by last used time
if model1.isDownloaded { if model1.isDownloaded {
let date1 = model1.lastUsedAt ?? .distantPast let date1 = model1.lastUsedAt ?? .distantPast
let date2 = model2.lastUsedAt ?? .distantPast let date2 = model2.lastUsedAt ?? .distantPast
@ -145,10 +177,10 @@ class ModelListViewModel: ObservableObject {
return false // Keep original order for not-downloaded return false // Keep original order for not-downloaded
} }
models = fetchedModels
} }
// MARK: - Model Selection & Usage
func selectModel(_ model: ModelInfo) { func selectModel(_ model: ModelInfo) {
if model.isDownloaded { if model.isDownloaded {
selectedModel = model selectedModel = model
@ -159,26 +191,35 @@ class ModelListViewModel: ObservableObject {
} }
} }
func recordModelUsage(modelName: String) {
ModelStorageManager.shared.updateLastUsed(for: modelName)
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
models[index].lastUsedAt = Date()
sortModels(fetchedModels: &models)
}
}
// MARK: - Download Management
func downloadModel(_ model: ModelInfo) async { func downloadModel(_ model: ModelInfo) async {
guard currentlyDownloading == nil else { return } guard currentlyDownloading == nil else { return }
currentlyDownloading = model.modelId currentlyDownloading = model.id
downloadProgress[model.modelId] = 0 downloadProgress[model.id] = 0
do { do {
try await modelClient.downloadModel(model: model) { progress in try await modelClient.downloadModel(model: model) { progress in
Task { @MainActor in Task { @MainActor in
self.downloadProgress[model.modelId] = progress self.downloadProgress[model.id] = progress
} }
} }
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) { if let index = models.firstIndex(where: { $0.id == model.id }) {
models[index].isDownloaded = true models[index].isDownloaded = true
ModelStorageManager.shared.markModelAsDownloaded(model.name) ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
} }
} catch { } catch {
if case ModelScopeError.downloadCancelled = error { if case ModelScopeError.downloadCancelled = error {
print("Download was cancelled") print("Download was cancelled")
} else { } else {
@ -188,7 +229,7 @@ class ModelListViewModel: ObservableObject {
} }
currentlyDownloading = nil currentlyDownloading = nil
downloadProgress.removeValue(forKey: model.modelId) downloadProgress.removeValue(forKey: model.id)
} }
func cancelDownload() async { func cancelDownload() async {
@ -201,24 +242,28 @@ class ModelListViewModel: ObservableObject {
print("Download cancelled for model: \(modelId)") print("Download cancelled for model: \(modelId)")
} }
} }
// MARK: - Pin Management
func pinModel(_ model: ModelInfo) { func pinModel(_ model: ModelInfo) {
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return } guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
let pinned = models.remove(at: index) let pinned = models.remove(at: index)
models.insert(pinned, at: 0) models.insert(pinned, at: 0)
var ids = pinnedModelIds.filter { $0 != model.modelId } var ids = pinnedModelIds.filter { $0 != model.id }
ids.append(model.modelId) ids.append(model.id)
pinnedModelIds = ids pinnedModelIds = ids
} }
func unpinModel(_ model: ModelInfo) { func unpinModel(_ model: ModelInfo) {
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return } guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
let unpinned = models.remove(at: index) let unpinned = models.remove(at: index)
let insertIndex = models.count // let insertIndex = models.count // Insert at end after unpinning
models.insert(unpinned, at: insertIndex) models.insert(unpinned, at: insertIndex)
pinnedModelIds = pinnedModelIds.filter { $0 != model.modelId } pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
} }
// MARK: - Model Deletion
func deleteModel(_ model: ModelInfo) async { func deleteModel(_ model: ModelInfo) async {
do { do {
let fileManager = FileManager.default let fileManager = FileManager.default
@ -240,11 +285,11 @@ class ModelListViewModel: ObservableObject {
} }
await MainActor.run { await MainActor.run {
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) { if let index = models.firstIndex(where: { $0.id == model.id }) {
models[index].isDownloaded = false models[index].isDownloaded = false
ModelStorageManager.shared.clearDownloadStatus(for: model.name) ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
} }
if selectedModel?.modelId == model.modelId { if selectedModel?.id == model.id {
selectedModel = nil selectedModel = nil
} }
} }

View File

@ -0,0 +1,24 @@
//
// TBDataResponse.swift
// MNNLLMiOS
//
// Created by () on 2025/7/9.
//
import Foundation
struct TBDataResponse: Codable {
let tagTranslations: [String: String]
let quickFilterTags: [String]?
let models: [ModelInfo]
let metadata: Metadata?
struct Metadata: Codable {
let version: String
let lastUpdated: String
let schemaVersion: String
let totalModels: Int
let supportedPlatforms: [String]
let minAppVersion: String
}
}

View File

@ -1,184 +0,0 @@
//
// TBModelInfo.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import Hub
import Foundation
struct TBModelInfo: Codable {
let modelName: String
let tags: [String]
let categories: [String]?
let size_gb: Double?
let vendor: String?
let sources: [String: String]?
let tagTranslations: [String: [String]]?
//
var isDownloaded: Bool = false
var lastUsedAt: Date?
var cachedSize: Int64? = nil
// IDmodelId
var id: String {
guard let sources = sources else {
return "taobao-mnn/\(modelName)"
}
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
}
// - 使tagTranslations
var localizedTags: [String] {
let currentLanguage = LanguageManager.shared.currentLanguage
let isChineseLanguage = currentLanguage == "简体中文"
if isChineseLanguage, let translations = tagTranslations {
let languageCode = "zh-Hans"
return translations[languageCode] ?? tags
} else {
return tags
}
}
var localPath: String {
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: id)).path
}
var formattedSize: String {
if isDownloaded {
return formatLocalSize()
} else if let cached = cachedSize {
return formatBytes(cached)
} else if let sizeGb = size_gb {
return String(format: "%.1f GB", sizeGb)
} else {
return "计算中..."
}
}
func fetchRemoteSize() async -> Int64? {
// TODO: now only support modelScope, support huggingFace later
let modelScopeId = "taobao-mnn/\(modelName)"
do {
let files = try await fetchFileList(repoPath: id, root: "", revision: "")
let totalSize = try await calculateTotalSize(files: files, repoPath: id)
return totalSize
} catch {
print("Error fetching remote size for \(id): \(error)")
return nil
}
}
private func formatLocalSize() -> String {
let path = localPath
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
do {
let totalSize = try calculateDirectorySize(at: path)
return formatBytes(totalSize)
} catch {
return "未知"
}
}
private func calculateDirectorySize(at path: String) throws -> Int64 {
let fileManager = FileManager.default
var totalSize: Int64 = 0
let enumerator = fileManager.enumerator(atPath: path)
while let fileName = enumerator?.nextObject() as? String {
let filePath = (path as NSString).appendingPathComponent(fileName)
let attributes = try fileManager.attributesOfItem(atPath: filePath)
if let fileSize = attributes[.size] as? Int64 {
totalSize += fileSize
}
}
return totalSize
}
private func formatBytes(_ bytes: Int64) -> String {
let formatter = ByteCountFormatter()
formatter.allowedUnits = [.useGB]
formatter.countStyle = .file
return formatter.string(fromByteCount: bytes)
}
// MARK: -
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
let url = try buildURL(
repoPath: repoPath,
path: "/repo/files",
queryItems: [
URLQueryItem(name: "Root", value: root),
URLQueryItem(name: "Revision", value: revision)
]
)
let (data, response) = try await URLSession.shared.data(from: url)
try validateResponse(response)
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
return modelResponse.data.files
}
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
var totalSize: Int64 = 0
for file in files {
if file.type == "tree" {
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
} else if file.type == "blob" {
totalSize += Int64(file.size)
}
}
return totalSize
}
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
var components = URLComponents()
components.scheme = "https"
components.host = "modelscope.cn"
components.path = "/api/v1/models/\(repoPath)\(path)"
components.queryItems = queryItems
guard let url = components.url else {
throw ModelScopeError.invalidURL
}
return url
}
private func validateResponse(_ response: URLResponse) throws {
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode) else {
throw ModelScopeError.invalidResponse
}
}
private enum CodingKeys: String, CodingKey {
case modelName
case tags
case categories
case size_gb
case vendor
case sources
case tagTranslations
case cachedSize
}
}
// TagTranslationManager
//extension TagTranslationManager {
// func getTranslation(for tag: String) -> String? {
// return globalTagTranslations[tag]
// }
//}

View File

@ -1,296 +0,0 @@
//
// TBModelListViewModel.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import Foundation
import SwiftUI
@MainActor
class TBModelListViewModel: ObservableObject {
@Published var models: [TBModelInfo] = []
@Published private(set) var downloadProgress: [String: Double] = [:]
@Published private(set) var currentlyDownloading: String?
@Published var showError = false
@Published var errorMessage = ""
@Published var searchText = ""
@Published var quickFilterTags: [String] = []
@Published var selectedModel: TBModelInfo?
private let modelClient = TBModelClient()
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
public var pinnedModelIds: [String] {
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
}
//
var allTags: [String] {
let allTags = Set(models.flatMap { $0.tags })
return Array(allTags)
}
//
var allCategories: [String] {
let allCategories = Set(models.compactMap { $0.categories }.flatMap { $0 })
return Array(allCategories)
}
//
var allVendors: [String] {
let allVendors = Set(models.compactMap { $0.vendor })
return Array(allVendors)
}
var filteredModels: [TBModelInfo] {
let filteredModels = searchText.isEmpty ? models : models.filter { model in
model.id.localizedCaseInsensitiveContains(searchText) ||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
}
let downloadedModels = filteredModels.filter { $0.isDownloaded }
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
return downloadedModels + notDownloadedModels
}
init() {
Task {
await fetchModels()
}
}
func fetchModels() async {
do {
let info = try await modelClient.getModelInfo()
self.quickFilterTags = info.quickFilterTags ?? []
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
var fetchedModels = info.models
self.filteDiffusionModel(fetchedModels: &fetchedModels)
self.sortModels(fetchedModels: &fetchedModels)
self.models = fetchedModels
//
Task {
await fetchModelSizes(for: fetchedModels)
}
} catch {
showError = true
errorMessage = "Error: \(error.localizedDescription)"
}
}
private func fetchModelSizes(for models: [TBModelInfo]) async {
await withTaskGroup(of: Void.self) { group in
for (_, model) in models.enumerated() {
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
group.addTask {
if let size = await model.fetchRemoteSize() {
await MainActor.run {
//
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
self.models[modelIndex].cachedSize = size
}
}
}
}
}
}
}
}
func recordModelUsage(modelName: String) {
ModelStorageManager.shared.updateLastUsed(for: modelName)
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
models[index].lastUsedAt = Date()
sortModels(fetchedModels: &models)
}
}
private func filteDiffusionModel(fetchedModels: inout [TBModelInfo]) {
let hasDiffusionModels = fetchedModels.contains {
$0.modelName.lowercased().contains("diffusion")
}
if hasDiffusionModels {
fetchedModels = fetchedModels.filter { model in
let name = model.modelName.lowercased()
let tags = model.tags.map { $0.lowercased() }
// only show gpu diffusion
if name.contains("diffusion") {
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
}
return true
}
}
for i in 0..<fetchedModels.count {
let model = fetchedModels[i]
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
}
}
private func sortModels(fetchedModels: inout [TBModelInfo]) {
let pinned = pinnedModelIds
fetchedModels.sort { (model1, model2) -> Bool in
let isPinned1 = pinned.contains(model1.id)
let isPinned2 = pinned.contains(model2.id)
let isDownloading1 = currentlyDownloading == model1.id
let isDownloading2 = currentlyDownloading == model2.id
// 1.
if isDownloading1 != isDownloading2 {
return isDownloading1
}
// 2.
if isPinned1 != isPinned2 {
return isPinned1
}
// 3.
if isPinned1 && isPinned2 {
let index1 = pinned.firstIndex(of: model1.id)!
let index2 = pinned.firstIndex(of: model2.id)!
return index1 > index2 // Pinned later comes first
}
// 4.
if model1.isDownloaded != model2.isDownloaded {
return model1.isDownloaded
}
// 5. 使
if model1.isDownloaded {
let date1 = model1.lastUsedAt ?? .distantPast
let date2 = model2.lastUsedAt ?? .distantPast
return date1 > date2
}
return false // Keep original order for not-downloaded
}
}
func selectModel(_ model: TBModelInfo) {
if model.isDownloaded {
selectedModel = model
} else {
Task {
await downloadModel(model)
}
}
}
func downloadModel(_ model: TBModelInfo) async {
guard currentlyDownloading == nil else { return }
currentlyDownloading = model.id
downloadProgress[model.id] = 0
do {
try await modelClient.downloadModel(model: model) { progress in
Task { @MainActor in
self.downloadProgress[model.id] = progress
}
}
if let index = models.firstIndex(where: { $0.id == model.id }) {
models[index].isDownloaded = true
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
}
} catch {
if case ModelScopeError.downloadCancelled = error {
print("Download was cancelled")
} else {
showError = true
errorMessage = "Failed to download model: \(error.localizedDescription)"
}
}
currentlyDownloading = nil
downloadProgress.removeValue(forKey: model.id)
}
func cancelDownload() async {
if let modelId = currentlyDownloading {
await modelClient.cancelDownload()
downloadProgress.removeValue(forKey: modelId)
currentlyDownloading = nil
print("Download cancelled for model: \(modelId)")
}
}
func pinModel(_ model: TBModelInfo) {
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
let pinned = models.remove(at: index)
models.insert(pinned, at: 0)
var ids = pinnedModelIds.filter { $0 != model.id }
ids.append(model.id)
pinnedModelIds = ids
}
func unpinModel(_ model: TBModelInfo) {
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
let unpinned = models.remove(at: index)
let insertIndex = models.count //
models.insert(unpinned, at: insertIndex)
pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
}
func deleteModel(_ model: TBModelInfo) async {
do {
let fileManager = FileManager.default
let modelPath = URL.init(filePath: model.localPath)
if let files = try? fileManager.contentsOfDirectory(
at: modelPath,
includingPropertiesForKeys: nil,
options: [.skipsHiddenFiles]
) {
let storage = ModelDownloadStorage()
for file in files {
storage.clearFileStatus(at: file.path)
}
}
if fileManager.fileExists(atPath: modelPath.path) {
try fileManager.removeItem(at: modelPath)
}
await MainActor.run {
if let index = models.firstIndex(where: { $0.id == model.id }) {
models[index].isDownloaded = false
ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
}
if selectedModel?.id == model.id {
selectedModel = nil
}
}
} catch {
print("Error deleting model: \(error)")
await MainActor.run {
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
self.showError = true
}
}
}
}

View File

@ -2,7 +2,7 @@
// ModelClient.swift // ModelClient.swift
// MNNLLMiOS // MNNLLMiOS
// //
// Created by () on 2025/1/3. // Created by () on 2025/7/4.
// //
import Hub import Hub
@ -26,17 +26,41 @@ class ModelClient {
init() {} init() {}
func getModelList() async throws -> [ModelInfo] { func getModelInfo() async throws -> TBDataResponse {
let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")! guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
return try await performRequest(url: url, retries: maxRetries) throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
return mockResponse
} }
func getRepoInfo(repoName: String, revision: String) async throws -> RepoInfo { func getModelList() async throws -> [ModelInfo] {
let url = URL(string: "\(baseURLString)/api/models/\(repoName)")! // TODO: get json from network
return try await performRequest(url: url, retries: maxRetries) // let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
// return try await performRequest(url: url, retries: maxRetries)
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
//
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
return mockResponse.models
} }
@MainActor /**
* Downloads a model from the selected source with progress tracking
*
* @param model The ModelInfo object containing model details
* @param progress Progress callback that receives download progress (0.0 to 1.0)
* @throws Various network or file system errors
*/
func downloadModel(model: ModelInfo, func downloadModel(model: ModelInfo,
progress: @escaping (Double) -> Void) async throws { progress: @escaping (Double) -> Void) async throws {
switch ModelSourceManager.shared.selectedSource { switch ModelSourceManager.shared.selectedSource {
@ -47,7 +71,9 @@ class ModelClient {
} }
} }
@MainActor /**
* Cancels the current download operation
*/
func cancelDownload() async { func cancelDownload() async {
if let manager = currentDownloadManager { if let manager = currentDownloadManager {
await manager.cancelDownload() await manager.cancelDownload()
@ -55,10 +81,16 @@ class ModelClient {
print("Download cancelled") print("Download cancelled")
} }
} }
/**
* Downloads model from ModelScope platform
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromModelScope(_ model: ModelInfo, private func downloadFromModelScope(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws { progress: @escaping (Double) -> Void) async throws {
let ModelScopeId = model.modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN") let ModelScopeId = model.id
let config = URLSessionConfiguration.default let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30 config.timeoutIntervalForRequest = 30
config.timeoutIntervalForResource = 300 config.timeoutIntervalForResource = 300
@ -66,56 +98,68 @@ class ModelClient {
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource) let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
currentDownloadManager = manager currentDownloadManager = manager
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.name) { fileProgress in try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
progress(fileProgress) Task { @MainActor in
progress(fileProgress)
}
} }
currentDownloadManager = nil currentDownloadManager = nil
} }
/**
* Downloads model from HuggingFace platform with optimized progress updates
*
* This method implements throttling to prevent UI stuttering by limiting
* progress update frequency and filtering out minor progress changes.
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromHuggingFace(_ model: ModelInfo, private func downloadFromHuggingFace(_ model: ModelInfo,
progress: @escaping (Double) -> Void) async throws { progress: @escaping (Double) -> Void) async throws {
let repo = Hub.Repo(id: model.modelId) let repo = Hub.Repo(id: model.id)
let modelFiles = ["*.*"] let modelFiles = ["*.*"]
let mirrorHubApi = HubApi(endpoint: baseURL) let mirrorHubApi = HubApi(endpoint: baseURL)
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
progress(fileProgress.fractionCompleted)
}
}
private func performRequest<T: Decodable>(url: URL, retries: Int = 3) async throws -> T {
var lastError: Error?
for attempt in 1...retries { // Progress throttling mechanism to prevent UI stuttering
do { var lastUpdateTime = Date()
var request = URLRequest(url: url) var lastProgress: Double = 0.0
request.setValue("application/json", forHTTPHeaderField: "Accept") let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
let currentProgress = fileProgress.fractionCompleted
let currentTime = Date()
// Check if progress should be updated
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
let progressDiff = abs(currentProgress - lastProgress)
// Update progress if any of these conditions are met:
// 1. Time interval exceeds threshold
// 2. Progress change exceeds threshold
// 3. Progress reaches 100% (download complete)
// 4. Progress is 0% (download start)
if timeDiff >= progressUpdateInterval ||
progressDiff >= progressThreshold ||
currentProgress >= 1.0 ||
currentProgress == 0.0 {
let (data, response) = try await URLSession.shared.data(for: request) lastUpdateTime = currentTime
lastProgress = currentProgress
guard let httpResponse = response as? HTTPURLResponse else { // Ensure progress updates are executed on the main thread
throw NetworkError.invalidResponse Task { @MainActor in
} progress(currentProgress)
if httpResponse.statusCode == 200 {
return try JSONDecoder().decode(T.self, from: data)
}
throw NetworkError.invalidResponse
} catch {
lastError = error
if attempt < retries {
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(attempt)) * 1_000_000_000))
continue
} }
} }
} }
throw lastError ?? NetworkError.unknown
} }
} }
enum NetworkError: Error { enum NetworkError: Error {
case invalidResponse case invalidResponse
case invalidData case invalidData

View File

@ -1,176 +0,0 @@
//
// TBModelClient.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import Hub
import Foundation
class TBModelClient {
private let baseMirrorURL = "https://hf-mirror.com"
private let baseURL = "https://huggingface.co"
private let maxRetries = 5
private var currentDownloadManager: ModelScopeDownloadManager?
private lazy var baseURLString: String = {
switch ModelSourceManager.shared.selectedSource {
case .huggingFace:
return baseURL
default:
return baseMirrorURL
}
}()
init() {}
func getModelInfo() async throws -> TBDataResponse {
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
return mockResponse
}
func getModelList() async throws -> [TBModelInfo] {
// TODO: get json from network
// let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
// return try await performRequest(url: url, retries: maxRetries)
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
throw NetworkError.invalidData
}
let data = try Data(contentsOf: url)
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
//
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
return mockResponse.models
}
/**
* Downloads a model from the selected source with progress tracking
*
* @param model The ModelInfo object containing model details
* @param progress Progress callback that receives download progress (0.0 to 1.0)
* @throws Various network or file system errors
*/
func downloadModel(model: TBModelInfo,
progress: @escaping (Double) -> Void) async throws {
switch ModelSourceManager.shared.selectedSource {
case .modelScope, .modeler:
try await downloadFromModelScope(model, progress: progress)
case .huggingFace:
try await downloadFromHuggingFace(model, progress: progress)
}
}
/**
* Cancels the current download operation
*/
func cancelDownload() async {
if let manager = currentDownloadManager {
await manager.cancelDownload()
currentDownloadManager = nil
print("Download cancelled")
}
}
/**
* Downloads model from ModelScope platform
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromModelScope(_ model: TBModelInfo,
progress: @escaping (Double) -> Void) async throws {
let ModelScopeId = model.id
let config = URLSessionConfiguration.default
config.timeoutIntervalForRequest = 30
config.timeoutIntervalForResource = 300
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
currentDownloadManager = manager
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
Task { @MainActor in
progress(fileProgress)
}
}
currentDownloadManager = nil
}
/**
* Downloads model from HuggingFace platform with optimized progress updates
*
* This method implements throttling to prevent UI stuttering by limiting
* progress update frequency and filtering out minor progress changes.
*
* @param model The ModelInfo object to download
* @param progress Progress callback for download updates
* @throws Download or network related errors
*/
private func downloadFromHuggingFace(_ model: TBModelInfo,
progress: @escaping (Double) -> Void) async throws {
let repo = Hub.Repo(id: model.id)
let modelFiles = ["*.*"]
let mirrorHubApi = HubApi(endpoint: baseURL)
// Progress throttling mechanism to prevent UI stuttering
var lastUpdateTime = Date()
var lastProgress: Double = 0.0
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
let currentProgress = fileProgress.fractionCompleted
let currentTime = Date()
// Check if progress should be updated
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
let progressDiff = abs(currentProgress - lastProgress)
// Update progress if any of these conditions are met:
// 1. Time interval exceeds threshold
// 2. Progress change exceeds threshold
// 3. Progress reaches 100% (download complete)
// 4. Progress is 0% (download start)
if timeDiff >= progressUpdateInterval ||
progressDiff >= progressThreshold ||
currentProgress >= 1.0 ||
currentProgress == 0.0 {
lastUpdateTime = currentTime
lastProgress = currentProgress
// Ensure progress updates are executed on the main thread
Task { @MainActor in
progress(currentProgress)
}
}
}
}
}
struct TBDataResponse: Codable {
let tagTranslations: [String: String]
let quickFilterTags: [String]?
let models: [TBModelInfo]
let metadata: MockMetadata?
struct MockMetadata: Codable {
let version: String
let lastUpdated: String
let schemaVersion: String
let totalModels: Int
let supportedPlatforms: [String]
let minAppVersion: String
}
}

View File

@ -10,7 +10,7 @@ import SwiftUI
// MARK: - // MARK: -
struct FilterMenuView: View { struct FilterMenuView: View {
@Environment(\.dismiss) private var dismiss @Environment(\.dismiss) private var dismiss
@StateObject private var viewModel = TBModelListViewModel() @StateObject private var viewModel = ModelListViewModel()
@Binding var selectedTags: Set<String> @Binding var selectedTags: Set<String>
@Binding var selectedCategories: Set<String> @Binding var selectedCategories: Set<String>
@Binding var selectedVendors: Set<String> @Binding var selectedVendors: Set<String>

View File

@ -9,7 +9,7 @@ import SwiftUI
// MARK: - // MARK: -
struct ToolbarView: View { struct ToolbarView: View {
@ObservedObject var viewModel: TBModelListViewModel @ObservedObject var viewModel: ModelListViewModel
@Binding var selectedSource: ModelSource @Binding var selectedSource: ModelSource
@Binding var showSourceMenu: Bool @Binding var showSourceMenu: Bool
@Binding var selectedTags: Set<String> @Binding var selectedTags: Set<String>

View File

@ -2,136 +2,146 @@
// ModelListView.swift // ModelListView.swift
// MNNLLMiOS // MNNLLMiOS
// //
// Created by () on 2025/1/3. // Created by () on 2025/7/4.
// //
import SwiftUI import SwiftUI
struct ModelListView: View { struct ModelListView: View {
@ObservedObject var viewModel: ModelListViewModel @ObservedObject var viewModel: ModelListViewModel
@State private var searchText = ""
@State private var scrollOffset: CGFloat = 0
@State private var showHelp = false
@State private var showUserGuide = false
@State private var downloadSources: ModelSource?
@State private var selectedSource = ModelSourceManager.shared.selectedSource @State private var selectedSource = ModelSourceManager.shared.selectedSource
@State private var showSourceMenu = false
@State private var showOptions = false @State private var selectedTags: Set<String> = []
@State private var buttonFrame: CGRect = .zero @State private var selectedCategories: Set<String> = []
@State private var selectedVendors: Set<String> = []
@State private var showFilterMenu = false
var body: some View { var body: some View {
ZStack { ScrollView {
VStack { LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
HStack { Section {
Button { modelListSection
showOptions.toggle() } header: {
} label: { toolbarSection
HStack {
Text("下载源:")
.font(.system(size: 12, weight: .regular))
.foregroundColor(showOptions ? .primaryBlue : .black )
Text(selectedSource.rawValue)
.font(.system(size: 12, weight: .regular))
.foregroundColor(showOptions ? .primaryBlue : .black )
Image(systemName: "chevron.down")
.frame(width: 10, height: 10, alignment: .leading)
.scaledToFit()
.foregroundColor(showOptions ? .primaryBlue : .black )
}
.padding(.leading)
}
Spacer()
} }
.frame(maxWidth: .infinity, maxHeight: 20)
.background(
GeometryReader { geometry in
Color.white.onAppear {
buttonFrame = geometry.frame(in: .global)
}
}
)
List {
SearchBar(text: $viewModel.searchText)
.listRowInsets(EdgeInsets())
.listRowSeparator(.hidden)
.padding(.horizontal)
ForEach(viewModel.filteredModels, id: \.modelId) { model in
ModelRowView(model: model,
viewModel: viewModel,
downloadProgress: viewModel.downloadProgress[model.modelId] ?? 0,
isDownloading: viewModel.currentlyDownloading == model.modelId,
isOtherDownloading: viewModel.currentlyDownloading != nil) {
if model.isDownloaded {
viewModel.selectModel(model)
} else {
Task {
await viewModel.downloadModel(model)
}
}
}
.listRowSeparator(.hidden)
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear)
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
SwipeActionsView(model: model, viewModel: viewModel)
}
}
}
.listStyle(.plain)
.sheet(isPresented: $showHelp) {
HelpView()
}
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK", role: .cancel) {}
} message: {
Text(viewModel.errorMessage)
}
.onAppear {
checkFirstLaunch()
}
.alert(isPresented: $showUserGuide) {
Alert(
title: Text("User Guide"),
message: Text("""
This is a local large model application that requires certain performance from your device.
It is recommended to choose different model sizes based on your device's memory.
The model recommendations for iPhone are as follows:
- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).
- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).
- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).
Choosing a model that is too large may cause insufficient memory and crashes.
"""),
dismissButton: .default(Text("OK"))
)
}
Spacer()
} }
}
.searchable(text: $searchText, prompt: "Search models...")
.onChange(of: searchText) { _, newValue in
viewModel.searchText = newValue
}
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK") { }
} message: {
Text(viewModel.errorMessage)
}
}
// Extract model list section as independent view
@ViewBuilder
private var modelListSection: some View {
LazyVStack(spacing: 8) {
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
modelRowView(model: model, index: index)
if index < filteredModels.count - 1 {
Divider()
.padding(.horizontal, 16)
}
}
}
.padding(.vertical, 8)
}
// Extract toolbar section as independent view
@ViewBuilder
private var toolbarSection: some View {
ToolbarView(
viewModel: viewModel, selectedSource: $selectedSource,
showSourceMenu: $showSourceMenu,
selectedTags: $selectedTags,
selectedCategories: $selectedCategories,
selectedVendors: $selectedVendors,
quickFilterTags: viewModel.quickFilterTags,
showFilterMenu: $showFilterMenu,
onSourceChange: handleSourceChange
)
}
// Extract single model row view as independent method
@ViewBuilder
private func modelRowView(model: ModelInfo, index: Int) -> some View {
ModelRowView(
model: model,
viewModel: viewModel,
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
isDownloading: viewModel.currentlyDownloading == model.id,
isOtherDownloading: isOtherDownloadingCheck(model: model)
) {
Task {
await viewModel.downloadModel(model)
}
}
.padding(.horizontal, 16)
}
// Extract complex boolean logic as independent method
private func isOtherDownloadingCheck(model: ModelInfo) -> Bool {
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
}
// Extract source change handling logic as independent method
private func handleSourceChange(_ source: ModelSource) {
ModelSourceManager.shared.updateSelectedSource(source)
selectedSource = source
Task {
await viewModel.fetchModels()
}
}
// Filter models based on selected tags, categories and vendors
private var filteredModels: [ModelInfo] {
let baseFiltered = viewModel.filteredModels
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
return baseFiltered
}
return baseFiltered.filter { model in
let tagMatch = checkTagMatch(model: model)
let categoryMatch = checkCategoryMatch(model: model)
let vendorMatch = checkVendorMatch(model: model)
if showOptions { return tagMatch && categoryMatch && vendorMatch
CustomPopupMenu(isPresented: $showOptions, }
selectedSource: $selectedSource, }
anchorFrame: buttonFrame)
// Extract tag matching logic as independent method
private func checkTagMatch(model: ModelInfo) -> Bool {
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
model.localizedTags.contains { tag in
tag.localizedCaseInsensitiveContains(selectedTag)
} }
} }
} }
private func checkFirstLaunch() { // Extract category matching logic as independent method
let hasLaunchedBefore = UserDefaults.standard.bool(forKey: "hasLaunchedBefore") private func checkCategoryMatch(model: ModelInfo) -> Bool {
if !hasLaunchedBefore { return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
// Show the user guide alert model.categories?.contains { category in
showUserGuide = true category.localizedCaseInsensitiveContains(selectedCategory)
// Set the flag to true so it doesn't show again } ?? false
UserDefaults.standard.set(true, forKey: "hasLaunchedBefore") }
}
// Extract vendor matching logic as independent method
private func checkVendorMatch(model: ModelInfo) -> Bool {
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
} }
} }
} }

View File

@ -9,8 +9,8 @@ import SwiftUI
// MARK: - // MARK: -
struct ActionButtonsView: View { struct ActionButtonsView: View {
let model: TBModelInfo let model: ModelInfo
@ObservedObject var viewModel: TBModelListViewModel @ObservedObject var viewModel: ModelListViewModel
let downloadProgress: Double let downloadProgress: Double
let isDownloading: Bool let isDownloading: Bool
let isOtherDownloading: Bool let isOtherDownloading: Bool
@ -40,4 +40,4 @@ struct ActionButtonsView: View {
} }
.frame(width: 60) .frame(width: 60)
} }
} }

View File

@ -9,7 +9,7 @@ import SwiftUI
// MARK: - // MARK: -
struct DownloadingButtonView: View { struct DownloadingButtonView: View {
@ObservedObject var viewModel: TBModelListViewModel @ObservedObject var viewModel: ModelListViewModel
let downloadProgress: Double let downloadProgress: Double
var body: some View { var body: some View {
@ -29,4 +29,4 @@ struct DownloadingButtonView: View {
} }
} }
} }
} }

View File

@ -2,7 +2,7 @@
// ModelRowView.swift // ModelRowView.swift
// MNNLLMiOS // MNNLLMiOS
// //
// Created by () on 2025/1/3. // Created by () on 2025/7/4.
// //
import SwiftUI import SwiftUI
@ -17,127 +17,78 @@ struct ModelRowView: View {
let isOtherDownloading: Bool let isOtherDownloading: Bool
let onDownload: () -> Void let onDownload: () -> Void
@State private var showDeleteAlert = false @State private var showDeleteAlert = false
//
private var localizedTags: [String] {
model.localizedTags
}
//
private var formattedSize: String {
model.formattedSize
}
var body: some View { var body: some View {
HStack(alignment: .top) { HStack(alignment: .top, spacing: 0) {
ModelIconView(modelId: model.modelId) //
ModelIconView(modelId: model.id)
.frame(width: 40, height: 40) .frame(width: 40, height: 40)
VStack(alignment: .leading, spacing: 5) { //
Text(model.name) VStack(alignment: .leading, spacing: 6) {
//
Text(model.modelName)
.font(.headline) .font(.headline)
.fontWeight(.semibold) .fontWeight(.semibold)
.lineLimit(1) .lineLimit(1)
if let lastUsedAt = model.lastUsedAt { //
Text("Last used: \(lastUsedAt.formatAgo())") if !localizedTags.isEmpty {
.font(.caption2) TagsView(tags: localizedTags)
.foregroundColor(.gray)
}
if !model.tags.isEmpty {
ScrollView(.horizontal, showsIndicators: false) {
HStack {
ForEach(model.tags, id: \.self) { tag in
Text(tag)
.fontWeight(.regular)
.font(.caption)
.foregroundColor(.gray)
.padding(.horizontal, 8)
.padding(.vertical, 3)
.background(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.gray.opacity(0.5), lineWidth: 0.5)
)
}
}
}
.frame(height: 25)
} }
} }
.padding(.leading, 8)
Spacer() Spacer()
VStack(alignment: .center, spacing: 4) { ActionButtonsView(
if model.isDownloaded { model: model,
Button(action: { viewModel: viewModel,
showDeleteAlert = true downloadProgress: downloadProgress,
}) { isDownloading: isDownloading,
Image(systemName: "trash") isOtherDownloading: isOtherDownloading,
.fontWeight(.regular) formattedSize: formattedSize,
.foregroundColor(.black.opacity(0.8)) onDownload: onDownload,
.frame(width: 20, height: 20) showDeleteAlert: $showDeleteAlert
)
Text("已下载")
.font(.caption2)
.foregroundColor(.gray)
.padding(.top, 4)
}
} else {
if isDownloading {
Button(action: {
Task {
await viewModel.cancelDownload()
}
}) {
ProgressView(value: downloadProgress)
.progressViewStyle(CircularProgressViewStyle())
.frame(width: 28, height: 28)
Text(String(format: "%.2f%%", downloadProgress * 100))
.font(.caption2)
.foregroundColor(.gray)
}
} else {
Button(action: onDownload) {
Image(systemName: "arrow.down.circle.fill")
.font(.title2)
}
.foregroundColor(isOtherDownloading ? .gray : .primaryPurple)
.disabled(isOtherDownloading)
HStack(alignment: .bottom, spacing: 2) {
Image(systemName: "folder")
.font(.caption2)
Text(model.formattedSize)
.font(.caption2)
.lineLimit(1)
.minimumScaleFactor(0.8)
.offset(y: 1)
.onAppear {
if !model.isDownloaded && model.cachedSize == nil {
Task {
if let size = await model.fetchRemoteSize() {
await MainActor.run {
if let index = viewModel.models.firstIndex(where: { $0.modelId == model.modelId }) {
viewModel.models[index].cachedSize = size
}
}
}
}
}
}
}
.foregroundColor(.gray)
}
}
}
.frame(width: 60)
} }
.padding(.vertical, 8) .padding(.vertical, 8)
.alert(isPresented: $showDeleteAlert) { .contentShape(Rectangle()) //
Alert( .onTapGesture {
title: Text("确认删除"), handleRowTap()
message: Text("是否确认删除该模型?"), }
primaryButton: .destructive(Text("删除")) { .alert("确认删除", isPresented: $showDeleteAlert) {
Task { Button("删除", role: .destructive) {
await viewModel.deleteModel(model) Task {
} await viewModel.deleteModel(model)
}, }
secondaryButton: .cancel(Text("取消")) }
) Button("取消", role: .cancel) { }
} message: {
Text("是否确认删除该模型?")
}
}
private func handleRowTap() {
if model.isDownloaded {
return
} else if isDownloading {
Task {
await viewModel.cancelDownload()
}
} else if !isOtherDownloading {
onDownload()
} }
} }
} }

View File

@ -13,7 +13,7 @@ struct SwipeActionsView: View {
@ObservedObject var viewModel: ModelListViewModel @ObservedObject var viewModel: ModelListViewModel
var body: some View { var body: some View {
if viewModel.pinnedModelIds.contains(model.modelId) { if viewModel.pinnedModelIds.contains(model.id) {
Button { Button {
viewModel.unpinModel(model) viewModel.unpinModel(model)
} label: { } label: {

View File

@ -1,147 +0,0 @@
//
// TBModelListView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import SwiftUI
struct TBModelListView: View {
@ObservedObject var viewModel: TBModelListViewModel
@State private var searchText = ""
@State private var selectedSource = ModelSourceManager.shared.selectedSource
@State private var showSourceMenu = false
@State private var selectedTags: Set<String> = []
@State private var selectedCategories: Set<String> = []
@State private var selectedVendors: Set<String> = []
@State private var showFilterMenu = false
var body: some View {
ScrollView {
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
Section {
modelListSection
} header: {
toolbarSection
}
}
}
.searchable(text: $searchText, prompt: "Search models...")
.onChange(of: searchText) { _, newValue in
viewModel.searchText = newValue
}
.refreshable {
await viewModel.fetchModels()
}
.alert("Error", isPresented: $viewModel.showError) {
Button("OK") { }
} message: {
Text(viewModel.errorMessage)
}
}
// Extract model list section as independent view
@ViewBuilder
private var modelListSection: some View {
LazyVStack(spacing: 8) {
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
modelRowView(model: model, index: index)
if index < filteredModels.count - 1 {
Divider()
.padding(.horizontal, 16)
}
}
}
.padding(.vertical, 8)
}
// Extract toolbar section as independent view
@ViewBuilder
private var toolbarSection: some View {
ToolbarView(
viewModel: viewModel, selectedSource: $selectedSource,
showSourceMenu: $showSourceMenu,
selectedTags: $selectedTags,
selectedCategories: $selectedCategories,
selectedVendors: $selectedVendors,
quickFilterTags: viewModel.quickFilterTags,
showFilterMenu: $showFilterMenu,
onSourceChange: handleSourceChange
)
}
// Extract single model row view as independent method
@ViewBuilder
private func modelRowView(model: TBModelInfo, index: Int) -> some View {
TBModelRowView(
model: model,
viewModel: viewModel,
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
isDownloading: viewModel.currentlyDownloading == model.id,
isOtherDownloading: isOtherDownloadingCheck(model: model)
) {
Task {
await viewModel.downloadModel(model)
}
}
.padding(.horizontal, 16)
}
// Extract complex boolean logic as independent method
private func isOtherDownloadingCheck(model: TBModelInfo) -> Bool {
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
}
// Extract source change handling logic as independent method
private func handleSourceChange(_ source: ModelSource) {
ModelSourceManager.shared.updateSelectedSource(source)
selectedSource = source
Task {
await viewModel.fetchModels()
}
}
// Filter models based on selected tags, categories and vendors
private var filteredModels: [TBModelInfo] {
let baseFiltered = viewModel.filteredModels
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
return baseFiltered
}
return baseFiltered.filter { model in
let tagMatch = checkTagMatch(model: model)
let categoryMatch = checkCategoryMatch(model: model)
let vendorMatch = checkVendorMatch(model: model)
return tagMatch && categoryMatch && vendorMatch
}
}
// Extract tag matching logic as independent method
private func checkTagMatch(model: TBModelInfo) -> Bool {
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
model.localizedTags.contains { tag in
tag.localizedCaseInsensitiveContains(selectedTag)
}
}
}
// Extract category matching logic as independent method
private func checkCategoryMatch(model: TBModelInfo) -> Bool {
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
model.categories?.contains { category in
category.localizedCaseInsensitiveContains(selectedCategory)
} ?? false
}
}
// Extract vendor matching logic as independent method
private func checkVendorMatch(model: TBModelInfo) -> Bool {
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
}
}
}

View File

@ -1,94 +0,0 @@
//
// TBModelRowView.swift
// MNNLLMiOS
//
// Created by () on 2025/7/4.
//
import SwiftUI
struct TBModelRowView: View {
let model: TBModelInfo
@ObservedObject var viewModel: TBModelListViewModel
let downloadProgress: Double
let isDownloading: Bool
let isOtherDownloading: Bool
let onDownload: () -> Void
@State private var showDeleteAlert = false
//
private var localizedTags: [String] {
model.localizedTags
}
//
private var formattedSize: String {
model.formattedSize
}
var body: some View {
HStack(alignment: .top, spacing: 0) {
//
ModelIconView(modelId: model.id)
.frame(width: 40, height: 40)
//
VStack(alignment: .leading, spacing: 6) {
//
Text(model.modelName)
.font(.headline)
.fontWeight(.semibold)
.lineLimit(1)
//
if !localizedTags.isEmpty {
TagsView(tags: localizedTags)
}
}
.padding(.leading, 8)
Spacer()
ActionButtonsView(
model: model,
viewModel: viewModel,
downloadProgress: downloadProgress,
isDownloading: isDownloading,
isOtherDownloading: isOtherDownloading,
formattedSize: formattedSize,
onDownload: onDownload,
showDeleteAlert: $showDeleteAlert
)
}
.padding(.vertical, 8)
.contentShape(Rectangle()) //
.onTapGesture {
handleRowTap()
}
.alert("确认删除", isPresented: $showDeleteAlert) {
Button("删除", role: .destructive) {
Task {
await viewModel.deleteModel(model)
}
}
Button("取消", role: .cancel) { }
} message: {
Text("是否确认删除该模型?")
}
}
private func handleRowTap() {
if model.isDownloaded {
return
} else if isDownloading {
Task {
await viewModel.cancelDownload()
}
} else if !isOtherDownloading {
onDownload()
}
}
}

View File

@ -19,6 +19,9 @@
} }
} }
} }
},
"Audio Message" : {
}, },
"Benchmark" : { "Benchmark" : {
@ -202,9 +205,6 @@
} }
} }
} }
},
"Last used: %@" : {
}, },
"Model Configuration" : { "Model Configuration" : {
"localizations" : { "localizations" : {
@ -365,11 +365,9 @@
} }
} }
} }
},
"TB模型" : {
}, },
"This is a local large model application that requires certain performance from your device.\nIt is recommended to choose different model sizes based on your device's memory. \n\nThe model recommendations for iPhone are as follows:\n- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).\n- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).\n- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).\n\nChoosing a model that is too large may cause insufficient memory and crashes." : { "This is a local large model application that requires certain performance from your device.\nIt is recommended to choose different model sizes based on your device's memory. \n\nThe model recommendations for iPhone are as follows:\n- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).\n- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).\n- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).\n\nChoosing a model that is too large may cause insufficient memory and crashes." : {
"extractionState" : "stale",
"localizations" : { "localizations" : {
"en" : { "en" : {
"stringUnit" : { "stringUnit" : {
@ -426,6 +424,7 @@
} }
}, },
"User Guide" : { "User Guide" : {
"extractionState" : "stale",
"localizations" : { "localizations" : {
"zh-Hans" : { "zh-Hans" : {
"stringUnit" : { "stringUnit" : {
@ -525,9 +524,6 @@
}, },
"语言" : { "语言" : {
},
"错误" : {
} }
}, },
"version" : "1.0" "version" : "1.0"