mirror of https://github.com/alibaba/MNN.git
[feat] change data source
This commit is contained in:
parent
3d2091bc24
commit
85855f7dbc
|
@ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
|
||||||
PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
|
PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
|
||||||
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
|
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
|
||||||
|
|
||||||
if let isDeepSeek = self?.modelInfo.name.lowercased().contains("deepseek"), isDeepSeek == true,
|
if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
|
||||||
let text = self?.processor.process(progress: message.text) {
|
let text = self?.processor.process(progress: message.text) {
|
||||||
updateLastMsg?.text = text
|
updateLastMsg?.text = text
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -25,13 +25,13 @@ final class LLMChatData {
|
||||||
|
|
||||||
self.assistant = LLMChatUser(
|
self.assistant = LLMChatUser(
|
||||||
uid: "2",
|
uid: "2",
|
||||||
name: modelInfo.name,
|
name: modelInfo.modelName,
|
||||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||||
)
|
)
|
||||||
|
|
||||||
self.system = LLMChatUser(
|
self.system = LLMChatUser(
|
||||||
uid: "0",
|
uid: "0",
|
||||||
name: modelInfo.name,
|
name: modelInfo.modelName,
|
||||||
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
avatar: AssetExtractor.createLocalUrl(forImageNamed: icon, withExtension: "png")
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
let modelConfigManager: ModelConfigManager
|
let modelConfigManager: ModelConfigManager
|
||||||
|
|
||||||
var isDiffusionModel: Bool {
|
var isDiffusionModel: Bool {
|
||||||
return modelInfo.name.lowercased().contains("diffusion")
|
return modelInfo.modelName.lowercased().contains("diffusion")
|
||||||
}
|
}
|
||||||
|
|
||||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||||
|
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
), userType: .system)
|
), userType: .system)
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelInfo.name.lowercased().contains("diffusion") {
|
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||||
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
||||||
Task { @MainActor in
|
Task { @MainActor in
|
||||||
print("Diffusion Model \(success)")
|
print("Diffusion Model \(success)")
|
||||||
|
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
func sendToLLM(draft: DraftMessage) {
|
func sendToLLM(draft: DraftMessage) {
|
||||||
self.send(draft: draft, userType: .user)
|
self.send(draft: draft, userType: .user)
|
||||||
if isModelLoaded {
|
if isModelLoaded {
|
||||||
if modelInfo.name.lowercased().contains("diffusion") {
|
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||||
self.getDiffusionResponse(draft: draft)
|
self.getDiffusionResponse(draft: draft)
|
||||||
} else {
|
} else {
|
||||||
self.getLLMRespsonse(draft: draft)
|
self.getLLMRespsonse(draft: draft)
|
||||||
|
@ -284,7 +284,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
private func convertDeepSeekMutliChat(content: String) -> String {
|
private func convertDeepSeekMutliChat(content: String) -> String {
|
||||||
if self.modelInfo.name.lowercased().contains("deepseek") {
|
if self.modelInfo.modelName.lowercased().contains("deepseek") {
|
||||||
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||||
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||||
*/
|
*/
|
||||||
|
@ -337,7 +337,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
ChatHistoryManager.shared.saveChat(
|
ChatHistoryManager.shared.saveChat(
|
||||||
historyId: historyId,
|
historyId: historyId,
|
||||||
modelId: modelInfo.modelId,
|
modelId: modelInfo.modelId,
|
||||||
modelName: modelInfo.name,
|
modelName: modelInfo.modelName,
|
||||||
messages: messages
|
messages: messages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -387,4 +387,4 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
print("Error accessing tmp directory: \(error.localizedDescription)")
|
print("Error accessing tmp directory: \(error.localizedDescription)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
let modelConfigManager: ModelConfigManager
|
let modelConfigManager: ModelConfigManager
|
||||||
|
|
||||||
var isDiffusionModel: Bool {
|
var isDiffusionModel: Bool {
|
||||||
return modelInfo.name.lowercased().contains("diffusion")
|
return modelInfo.modelName.lowercased().contains("diffusion")
|
||||||
}
|
}
|
||||||
|
|
||||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||||
|
@ -88,7 +88,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
), userType: .system)
|
), userType: .system)
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelInfo.name.lowercased().contains("diffusion") {
|
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||||
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
diffusion = DiffusionSession(modelPath: modelPath, completion: { [weak self] success in
|
||||||
Task { @MainActor in
|
Task { @MainActor in
|
||||||
print("Diffusion Model \(success)")
|
print("Diffusion Model \(success)")
|
||||||
|
@ -150,7 +150,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
func sendToLLM(draft: DraftMessage) {
|
func sendToLLM(draft: DraftMessage) {
|
||||||
self.send(draft: draft, userType: .user)
|
self.send(draft: draft, userType: .user)
|
||||||
if isModelLoaded {
|
if isModelLoaded {
|
||||||
if modelInfo.name.lowercased().contains("diffusion") {
|
if modelInfo.modelName.lowercased().contains("diffusion") {
|
||||||
self.getDiffusionResponse(draft: draft)
|
self.getDiffusionResponse(draft: draft)
|
||||||
} else {
|
} else {
|
||||||
self.getLLMRespsonse(draft: draft)
|
self.getLLMRespsonse(draft: draft)
|
||||||
|
@ -298,7 +298,7 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
private func convertDeepSeekMutliChat(content: String) -> String {
|
private func convertDeepSeekMutliChat(content: String) -> String {
|
||||||
if self.modelInfo.name.lowercased().contains("deepseek") {
|
if self.modelInfo.modelName.lowercased().contains("deepseek") {
|
||||||
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
/* formate:: <|begin_of_sentence|><|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||||
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
<|User|>{text}<|Assistant|>{text}<|end_of_sentence|>
|
||||||
*/
|
*/
|
||||||
|
@ -350,8 +350,8 @@ final class LLMChatViewModel: ObservableObject {
|
||||||
func onStop() {
|
func onStop() {
|
||||||
ChatHistoryManager.shared.saveChat(
|
ChatHistoryManager.shared.saveChat(
|
||||||
historyId: historyId,
|
historyId: historyId,
|
||||||
modelId: modelInfo.modelId,
|
modelId: modelInfo.id,
|
||||||
modelName: modelInfo.name,
|
modelName: modelInfo.modelName,
|
||||||
messages: messages
|
messages: messages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct LLMChatView: View {
|
||||||
@State private var showSettings = false
|
@State private var showSettings = false
|
||||||
|
|
||||||
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
init(modelInfo: ModelInfo, history: ChatHistory? = nil) {
|
||||||
self.title = modelInfo.name
|
self.title = modelInfo.modelName
|
||||||
self.modelPath = modelInfo.localPath
|
self.modelPath = modelInfo.localPath
|
||||||
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
|
let viewModel = LLMChatViewModel(modelInfo: modelInfo, history: history)
|
||||||
_viewModel = StateObject(wrappedValue: viewModel)
|
_viewModel = StateObject(wrappedValue: viewModel)
|
||||||
|
|
|
@ -12,14 +12,14 @@ struct LocalModelListView: View {
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
List {
|
List {
|
||||||
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.modelId) { model in
|
ForEach(viewModel.filteredModels.filter { $0.isDownloaded }, id: \.id) { model in
|
||||||
Button(action: {
|
Button(action: {
|
||||||
viewModel.selectModel(model)
|
viewModel.selectModel(model)
|
||||||
}) {
|
}) {
|
||||||
LocalModelRowView(model: model)
|
LocalModelRowView(model: model)
|
||||||
}
|
}
|
||||||
.listRowSeparator(.hidden)
|
.listRowSeparator(.hidden)
|
||||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear)
|
.listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
|
||||||
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
||||||
SwipeActionsView(model: model, viewModel: viewModel)
|
SwipeActionsView(model: model, viewModel: viewModel)
|
||||||
}
|
}
|
||||||
|
@ -35,4 +35,4 @@ struct LocalModelListView: View {
|
||||||
Text(viewModel.errorMessage)
|
Text(viewModel.errorMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,36 +11,28 @@ struct LocalModelRowView: View {
|
||||||
|
|
||||||
let model: ModelInfo
|
let model: ModelInfo
|
||||||
|
|
||||||
|
private var localizedTags: [String] {
|
||||||
|
model.localizedTags
|
||||||
|
}
|
||||||
|
|
||||||
|
private var formattedSize: String {
|
||||||
|
model.formattedSize
|
||||||
|
}
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
HStack(alignment: .center) {
|
HStack(alignment: .center) {
|
||||||
|
|
||||||
ModelIconView(modelId: model.modelId)
|
ModelIconView(modelId: model.id)
|
||||||
.frame(width: 50, height: 50)
|
.frame(width: 50, height: 50)
|
||||||
|
|
||||||
VStack(alignment: .leading, spacing: 8) {
|
VStack(alignment: .leading, spacing: 8) {
|
||||||
Text(model.name)
|
Text(model.modelName)
|
||||||
.font(.headline)
|
.font(.headline)
|
||||||
.fontWeight(.semibold)
|
.fontWeight(.semibold)
|
||||||
.lineLimit(1)
|
.lineLimit(1)
|
||||||
|
|
||||||
if !model.tags.isEmpty {
|
if !localizedTags.isEmpty {
|
||||||
ScrollView(.horizontal, showsIndicators: false) {
|
TagsView(tags: localizedTags)
|
||||||
HStack {
|
|
||||||
ForEach(model.tags, id: \.self) { tag in
|
|
||||||
Text(tag)
|
|
||||||
.fontWeight(.regular)
|
|
||||||
.font(.caption)
|
|
||||||
.foregroundColor(Color(red: 151/255, green: 151/255, blue: 151/255))
|
|
||||||
.padding(.horizontal, 10)
|
|
||||||
.padding(.vertical, 4)
|
|
||||||
.background(
|
|
||||||
RoundedRectangle(cornerRadius: 10)
|
|
||||||
.stroke(Color(red: 151/255, green: 151/255, blue: 151/255), lineWidth: 0.5)
|
|
||||||
.padding(1)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
HStack {
|
HStack {
|
||||||
|
@ -51,7 +43,7 @@ struct LocalModelRowView: View {
|
||||||
.foregroundColor(.gray)
|
.foregroundColor(.gray)
|
||||||
.frame(width: 20, height: 20)
|
.frame(width: 20, height: 20)
|
||||||
|
|
||||||
Text(model.formattedSize)
|
Text(formattedSize)
|
||||||
.font(.caption)
|
.font(.caption)
|
||||||
.fontWeight(.medium)
|
.fontWeight(.medium)
|
||||||
.foregroundColor(.gray)
|
.foregroundColor(.gray)
|
||||||
|
|
|
@ -16,39 +16,32 @@ struct MainTabView: View {
|
||||||
@State private var showWebView = false
|
@State private var showWebView = false
|
||||||
@State private var webViewURL: URL?
|
@State private var webViewURL: URL?
|
||||||
@State private var navigateToSettings = false
|
@State private var navigateToSettings = false
|
||||||
@StateObject private var modelListViewModel = TBModelListViewModel()
|
@StateObject private var modelListViewModel = ModelListViewModel()
|
||||||
@StateObject private var localModelListViewModel = ModelListViewModel()
|
|
||||||
@State private var selectedTab: Int = 0
|
@State private var selectedTab: Int = 0
|
||||||
@State private var titles = ["本地模型", "模型市场", "TB模型", "Benchmark"]
|
@State private var titles = ["本地模型", "模型市场", "Benchmark"]
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
ZStack {
|
ZStack {
|
||||||
NavigationView {
|
NavigationView {
|
||||||
TabView(selection: $selectedTab) {
|
TabView(selection: $selectedTab) {
|
||||||
LocalModelListView(viewModel: localModelListViewModel)
|
LocalModelListView(viewModel: modelListViewModel)
|
||||||
.tabItem {
|
.tabItem {
|
||||||
Image(systemName: "house.fill")
|
Image(systemName: "house.fill")
|
||||||
Text("本地模型")
|
Text("本地模型")
|
||||||
}
|
}
|
||||||
.tag(0)
|
.tag(0)
|
||||||
ModelListView(viewModel: localModelListViewModel)
|
ModelListView(viewModel: modelListViewModel)
|
||||||
.tabItem {
|
.tabItem {
|
||||||
Image(systemName: "cart.fill")
|
Image(systemName: "doc.text.fill")
|
||||||
Text("模型市场")
|
Text("模型市场")
|
||||||
}
|
}
|
||||||
.tag(1)
|
.tag(1)
|
||||||
TBModelListView(viewModel: modelListViewModel)
|
|
||||||
.tabItem {
|
|
||||||
Image(systemName: "doc.text.fill")
|
|
||||||
Text("TB模型")
|
|
||||||
}
|
|
||||||
.tag(2)
|
|
||||||
BenchmarkView()
|
BenchmarkView()
|
||||||
.tabItem {
|
.tabItem {
|
||||||
Image(systemName: "clock.fill")
|
Image(systemName: "clock.fill")
|
||||||
Text("Benchmark")
|
Text("Benchmark")
|
||||||
}
|
}
|
||||||
.tag(3)
|
.tag(2)
|
||||||
}
|
}
|
||||||
.background(
|
.background(
|
||||||
ZStack {
|
ZStack {
|
||||||
|
@ -130,19 +123,13 @@ struct MainTabView: View {
|
||||||
|
|
||||||
@ViewBuilder
|
@ViewBuilder
|
||||||
private var chatDestination: some View {
|
private var chatDestination: some View {
|
||||||
if let model = localModelListViewModel.selectedModel {
|
if let model = modelListViewModel.selectedModel {
|
||||||
LLMChatView(modelInfo: model)
|
LLMChatView(modelInfo: model)
|
||||||
.navigationBarHidden(false)
|
.navigationBarHidden(false)
|
||||||
.navigationBarTitleDisplayMode(.inline)
|
.navigationBarTitleDisplayMode(.inline)
|
||||||
.toolbar(.hidden, for: .tabBar) // Hide tab bar in chat
|
.toolbar(.hidden, for: .tabBar) // Hide tab bar in chat
|
||||||
} else if let history = selectedHistory {
|
} else if let history = selectedHistory {
|
||||||
let modelInfo = ModelInfo(
|
let modelInfo = ModelInfo(modelId: history.modelId, isDownloaded: true)
|
||||||
modelId: history.modelId,
|
|
||||||
createdAt: "",
|
|
||||||
downloads: 0,
|
|
||||||
tags: [],
|
|
||||||
isDownloaded: true
|
|
||||||
)
|
|
||||||
LLMChatView(modelInfo: modelInfo, history: history)
|
LLMChatView(modelInfo: modelInfo, history: history)
|
||||||
.navigationBarHidden(false)
|
.navigationBarHidden(false)
|
||||||
.navigationBarTitleDisplayMode(.inline)
|
.navigationBarTitleDisplayMode(.inline)
|
||||||
|
@ -155,17 +142,17 @@ struct MainTabView: View {
|
||||||
private var chatIsActiveBinding: Binding<Bool> {
|
private var chatIsActiveBinding: Binding<Bool> {
|
||||||
Binding<Bool>(
|
Binding<Bool>(
|
||||||
get: {
|
get: {
|
||||||
return localModelListViewModel.selectedModel != nil || selectedHistory != nil
|
return modelListViewModel.selectedModel != nil || selectedHistory != nil
|
||||||
},
|
},
|
||||||
set: { isActive in
|
set: { isActive in
|
||||||
if !isActive {
|
if !isActive {
|
||||||
// Record usage when returning from chat
|
// Record usage when returning from chat
|
||||||
if let model = localModelListViewModel.selectedModel {
|
if let model = modelListViewModel.selectedModel {
|
||||||
localModelListViewModel.recordModelUsage(modelName: model.name)
|
modelListViewModel.recordModelUsage(modelName: model.modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear selections
|
// Clear selections
|
||||||
localModelListViewModel.selectedModel = nil
|
modelListViewModel.selectedModel = nil
|
||||||
selectedHistory = nil
|
selectedHistory = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,4 +182,4 @@ struct MainTabView: View {
|
||||||
UITabBar.appearance().standardAppearance = appearance
|
UITabBar.appearance().standardAppearance = appearance
|
||||||
UITabBar.appearance().scrollEdgeAppearance = appearance
|
UITabBar.appearance().scrollEdgeAppearance = appearance
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,83 +1,108 @@
|
||||||
//
|
//
|
||||||
// ModelClient.swift
|
// TBModelInfo.swift
|
||||||
// MNNLLMiOS
|
// MNNLLMiOS
|
||||||
//
|
//
|
||||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||||
//
|
//
|
||||||
|
|
||||||
import Hub
|
import Hub
|
||||||
import Foundation
|
import Foundation
|
||||||
|
|
||||||
struct ModelInfo: Codable {
|
struct ModelInfo: Codable {
|
||||||
let modelId: String
|
// MARK: - Properties
|
||||||
let createdAt: String
|
let modelName: String
|
||||||
let downloads: Int
|
|
||||||
let tags: [String]
|
let tags: [String]
|
||||||
|
let categories: [String]?
|
||||||
|
let size_gb: Double?
|
||||||
|
let vendor: String?
|
||||||
|
let sources: [String: String]?
|
||||||
|
let tagTranslations: [String: [String]]?
|
||||||
|
|
||||||
var name: String {
|
// Runtime properties
|
||||||
modelId.removingTaobaoPrefix()
|
|
||||||
}
|
|
||||||
|
|
||||||
var isDownloaded: Bool = false
|
var isDownloaded: Bool = false
|
||||||
var lastUsedAt: Date?
|
var lastUsedAt: Date?
|
||||||
|
|
||||||
var cachedSize: Int64? = nil
|
var cachedSize: Int64? = nil
|
||||||
|
|
||||||
var localPath: String {
|
// MARK: - Initialization
|
||||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
|
|
||||||
|
init(modelName: String = "",
|
||||||
|
tags: [String] = [],
|
||||||
|
categories: [String]? = nil,
|
||||||
|
size_gb: Double? = nil,
|
||||||
|
vendor: String? = nil,
|
||||||
|
sources: [String: String]? = nil,
|
||||||
|
tagTranslations: [String: [String]]? = nil,
|
||||||
|
isDownloaded: Bool = false,
|
||||||
|
lastUsedAt: Date? = nil,
|
||||||
|
cachedSize: Int64? = nil) {
|
||||||
|
|
||||||
|
self.modelName = modelName
|
||||||
|
self.tags = tags
|
||||||
|
self.categories = categories
|
||||||
|
self.size_gb = size_gb
|
||||||
|
self.vendor = vendor
|
||||||
|
self.sources = sources
|
||||||
|
self.tagTranslations = tagTranslations
|
||||||
|
self.isDownloaded = isDownloaded
|
||||||
|
self.lastUsedAt = lastUsedAt
|
||||||
|
self.cachedSize = cachedSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
init(modelId: String, isDownloaded: Bool = true) {
|
||||||
|
let modelName = modelId.components(separatedBy: "/").last ?? modelId
|
||||||
|
|
||||||
|
self.init(
|
||||||
|
modelName: modelName,
|
||||||
|
tags: [],
|
||||||
|
sources: ["huggingface": modelId],
|
||||||
|
isDownloaded: isDownloaded
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Identity & Localization
|
||||||
|
|
||||||
|
var id: String {
|
||||||
|
guard let sources = sources else {
|
||||||
|
return "taobao-mnn/\(modelName)"
|
||||||
|
}
|
||||||
|
|
||||||
|
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
|
||||||
|
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
|
||||||
|
}
|
||||||
|
|
||||||
|
var localizedTags: [String] {
|
||||||
|
let currentLanguage = LanguageManager.shared.currentLanguage
|
||||||
|
let isChineseLanguage = currentLanguage == "简体中文"
|
||||||
|
|
||||||
|
if isChineseLanguage, let translations = tagTranslations {
|
||||||
|
let languageCode = "zh-Hans"
|
||||||
|
return translations[languageCode] ?? tags
|
||||||
|
} else {
|
||||||
|
return tags
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - File System & Path Management
|
||||||
|
|
||||||
|
var localPath: String {
|
||||||
|
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||||
|
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelScopeId)).path
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Size Calculation & Formatting
|
||||||
|
|
||||||
var formattedSize: String {
|
var formattedSize: String {
|
||||||
if isDownloaded {
|
if let cached = cachedSize {
|
||||||
return formatLocalSize()
|
|
||||||
} else if let cached = cachedSize {
|
|
||||||
return formatBytes(cached)
|
return formatBytes(cached)
|
||||||
|
} else if isDownloaded {
|
||||||
|
return formatLocalSize()
|
||||||
|
} else if let sizeGb = size_gb {
|
||||||
|
return String(format: "%.1f GB", sizeGb)
|
||||||
} else {
|
} else {
|
||||||
return "计算中..."
|
return "Calculating..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchRemoteSize() async -> Int64? {
|
|
||||||
let modelScopeId = modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
|
|
||||||
|
|
||||||
do {
|
|
||||||
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
|
||||||
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
|
||||||
return totalSize
|
|
||||||
} catch {
|
|
||||||
print("Error fetching remote size for \(modelId): \(error)")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func formatLocalSize() -> String {
|
|
||||||
let path = localPath
|
|
||||||
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
|
|
||||||
|
|
||||||
do {
|
|
||||||
let totalSize = try calculateDirectorySize(at: path)
|
|
||||||
return formatBytes(totalSize)
|
|
||||||
} catch {
|
|
||||||
return "未知"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
|
||||||
let fileManager = FileManager.default
|
|
||||||
var totalSize: Int64 = 0
|
|
||||||
|
|
||||||
let enumerator = fileManager.enumerator(atPath: path)
|
|
||||||
while let fileName = enumerator?.nextObject() as? String {
|
|
||||||
let filePath = (path as NSString).appendingPathComponent(fileName)
|
|
||||||
let attributes = try fileManager.attributesOfItem(atPath: filePath)
|
|
||||||
if let fileSize = attributes[.size] as? Int64 {
|
|
||||||
totalSize += fileSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalSize
|
|
||||||
}
|
|
||||||
|
|
||||||
private func formatBytes(_ bytes: Int64) -> String {
|
private func formatBytes(_ bytes: Int64) -> String {
|
||||||
let formatter = ByteCountFormatter()
|
let formatter = ByteCountFormatter()
|
||||||
formatter.allowedUnits = [.useGB]
|
formatter.allowedUnits = [.useGB]
|
||||||
|
@ -85,7 +110,102 @@ struct ModelInfo: Codable {
|
||||||
return formatter.string(fromByteCount: bytes)
|
return formatter.string(fromByteCount: bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - 云端文件大小计算方法
|
// MARK: - Local Size Calculation
|
||||||
|
|
||||||
|
private func formatLocalSize() -> String {
|
||||||
|
let path = localPath
|
||||||
|
guard FileManager.default.fileExists(atPath: path) else { return "Unknown" }
|
||||||
|
|
||||||
|
do {
|
||||||
|
let totalSize = try calculateDirectorySize(at: path)
|
||||||
|
return formatBytes(totalSize)
|
||||||
|
} catch {
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
||||||
|
let fileManager = FileManager.default
|
||||||
|
var totalSize: Int64 = 0
|
||||||
|
|
||||||
|
print("Calculating directory size for path: \(path)")
|
||||||
|
|
||||||
|
let directoryURL = URL(fileURLWithPath: path)
|
||||||
|
|
||||||
|
guard fileManager.fileExists(atPath: path) else {
|
||||||
|
print("Path does not exist: \(path)")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
let resourceKeys: [URLResourceKey] = [.isRegularFileKey, .totalFileAllocatedSizeKey, .fileSizeKey, .nameKey]
|
||||||
|
let enumerator = fileManager.enumerator(
|
||||||
|
at: directoryURL,
|
||||||
|
includingPropertiesForKeys: resourceKeys,
|
||||||
|
options: [.skipsHiddenFiles, .skipsPackageDescendants],
|
||||||
|
errorHandler: { (url, error) -> Bool in
|
||||||
|
print("Error accessing \(url): \(error)")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
guard let fileEnumerator = enumerator else {
|
||||||
|
throw NSError(domain: "FileEnumerationError", code: -1,
|
||||||
|
userInfo: [NSLocalizedDescriptionKey: "Failed to create file enumerator"])
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileCount = 0
|
||||||
|
for case let fileURL as URL in fileEnumerator {
|
||||||
|
do {
|
||||||
|
let resourceValues = try fileURL.resourceValues(forKeys: Set(resourceKeys))
|
||||||
|
|
||||||
|
guard let isRegularFile = resourceValues.isRegularFile, isRegularFile else { continue }
|
||||||
|
|
||||||
|
let fileName = resourceValues.name ?? "Unknown"
|
||||||
|
fileCount += 1
|
||||||
|
|
||||||
|
// Use actual disk allocated size, fallback to logical size if not available
|
||||||
|
if let actualSize = resourceValues.totalFileAllocatedSize {
|
||||||
|
totalSize += Int64(actualSize)
|
||||||
|
|
||||||
|
if fileCount <= 10 {
|
||||||
|
let actualSizeGB = Double(actualSize) / (1024 * 1024 * 1024)
|
||||||
|
let logicalSizeGB = Double(resourceValues.fileSize ?? 0) / (1024 * 1024 * 1024)
|
||||||
|
print("File \(fileCount): \(fileName) - Logical: \(String(format: "%.3f", logicalSizeGB)) GB, Actual: \(String(format: "%.3f", actualSizeGB)) GB")
|
||||||
|
}
|
||||||
|
} else if let logicalSize = resourceValues.fileSize {
|
||||||
|
totalSize += Int64(logicalSize)
|
||||||
|
|
||||||
|
if fileCount <= 10 {
|
||||||
|
let logicalSizeGB = Double(logicalSize) / (1024 * 1024 * 1024)
|
||||||
|
print("File \(fileCount): \(fileName) - Size: \(String(format: "%.3f", logicalSizeGB)) GB (fallback)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
print("Error getting resource values for \(fileURL): \(error)")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let totalSizeGB = Double(totalSize) / (1024 * 1024 * 1024)
|
||||||
|
print("Total files: \(fileCount), Total actual disk usage: \(String(format: "%.2f", totalSizeGB)) GB")
|
||||||
|
|
||||||
|
return totalSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Remote Size Calculation
|
||||||
|
|
||||||
|
func fetchRemoteSize() async -> Int64? {
|
||||||
|
let modelScopeId = "taobao-mnn/\(modelName)"
|
||||||
|
|
||||||
|
do {
|
||||||
|
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
||||||
|
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
||||||
|
return totalSize
|
||||||
|
} catch {
|
||||||
|
print("Error fetching remote size for \(id): \(error)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
||||||
let url = try buildURL(
|
let url = try buildURL(
|
||||||
|
@ -119,6 +239,8 @@ struct ModelInfo: Codable {
|
||||||
return totalSize
|
return totalSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Network Utilities
|
||||||
|
|
||||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
||||||
var components = URLComponents()
|
var components = URLComponents()
|
||||||
components.scheme = "https"
|
components.scheme = "https"
|
||||||
|
@ -139,21 +261,9 @@ struct ModelInfo: Codable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Codable
|
||||||
|
|
||||||
private enum CodingKeys: String, CodingKey {
|
private enum CodingKeys: String, CodingKey {
|
||||||
case modelId
|
case modelName, tags, categories, size_gb, vendor, sources, tagTranslations, cachedSize
|
||||||
case tags
|
|
||||||
case downloads
|
|
||||||
case createdAt
|
|
||||||
case cachedSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RepoInfo: Codable {
|
|
||||||
let modelId: String
|
|
||||||
let sha: String
|
|
||||||
let siblings: [Sibling]
|
|
||||||
|
|
||||||
struct Sibling: Codable {
|
|
||||||
let rfilename: String
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
// ModelListViewModel.swift
|
// ModelListViewModel.swift
|
||||||
// MNNLLMiOS
|
// MNNLLMiOS
|
||||||
//
|
//
|
||||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||||
//
|
//
|
||||||
|
|
||||||
import Foundation
|
import Foundation
|
||||||
|
@ -10,74 +10,78 @@ import SwiftUI
|
||||||
|
|
||||||
@MainActor
|
@MainActor
|
||||||
class ModelListViewModel: ObservableObject {
|
class ModelListViewModel: ObservableObject {
|
||||||
|
// MARK: - Published Properties
|
||||||
@Published var models: [ModelInfo] = []
|
@Published var models: [ModelInfo] = []
|
||||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
@Published var searchText = ""
|
||||||
@Published private(set) var currentlyDownloading: String?
|
@Published var quickFilterTags: [String] = []
|
||||||
|
@Published var selectedModel: ModelInfo?
|
||||||
@Published var showError = false
|
@Published var showError = false
|
||||||
@Published var errorMessage = ""
|
@Published var errorMessage = ""
|
||||||
@Published var searchText = ""
|
|
||||||
|
|
||||||
@Published var selectedModel: ModelInfo?
|
// Download state
|
||||||
|
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||||
|
@Published private(set) var currentlyDownloading: String?
|
||||||
|
|
||||||
|
// MARK: - Private Properties
|
||||||
private let modelClient = ModelClient()
|
private let modelClient = ModelClient()
|
||||||
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
||||||
|
|
||||||
|
// MARK: - Model Data Access
|
||||||
|
|
||||||
public var pinnedModelIds: [String] {
|
public var pinnedModelIds: [String] {
|
||||||
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
||||||
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var allTags: [String] {
|
||||||
|
Array(Set(models.flatMap { $0.tags }))
|
||||||
|
}
|
||||||
|
|
||||||
|
var allCategories: [String] {
|
||||||
|
Array(Set(models.compactMap { $0.categories }.flatMap { $0 }))
|
||||||
|
}
|
||||||
|
|
||||||
|
var allVendors: [String] {
|
||||||
|
Array(Set(models.compactMap { $0.vendor }))
|
||||||
|
}
|
||||||
|
|
||||||
var filteredModels: [ModelInfo] {
|
var filteredModels: [ModelInfo] {
|
||||||
|
let filtered = searchText.isEmpty ? models : models.filter { model in
|
||||||
let filteredModels = searchText.isEmpty ? models : models.filter { model in
|
model.id.localizedCaseInsensitiveContains(searchText) ||
|
||||||
model.modelId.localizedCaseInsensitiveContains(searchText) ||
|
model.modelName.localizedCaseInsensitiveContains(searchText) ||
|
||||||
model.tags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
||||||
}
|
}
|
||||||
|
|
||||||
let downloadedModels = filteredModels.filter { $0.isDownloaded }
|
let downloaded = filtered.filter { $0.isDownloaded }
|
||||||
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
|
let notDownloaded = filtered.filter { !$0.isDownloaded }
|
||||||
|
|
||||||
return downloadedModels + notDownloadedModels
|
return downloaded + notDownloaded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
init() {
|
init() {
|
||||||
Task {
|
Task {
|
||||||
await fetchModels()
|
await fetchModels()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Data Management
|
||||||
|
|
||||||
func fetchModels() async {
|
func fetchModels() async {
|
||||||
do {
|
do {
|
||||||
var fetchedModels = try await modelClient.getModelList()
|
let info = try await modelClient.getModelInfo()
|
||||||
|
|
||||||
let hasDiffusionModels = fetchedModels.contains {
|
self.quickFilterTags = info.quickFilterTags ?? []
|
||||||
$0.name.lowercased().contains("diffusion")
|
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
|
||||||
}
|
|
||||||
|
|
||||||
if hasDiffusionModels {
|
var fetchedModels = info.models
|
||||||
fetchedModels = fetchedModels.filter { model in
|
|
||||||
let name = model.name.lowercased()
|
|
||||||
let tags = model.tags.map { $0.lowercased() }
|
|
||||||
|
|
||||||
// only show gpu diffusion
|
|
||||||
if name.contains("diffusion") {
|
|
||||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in 0..<fetchedModels.count {
|
filterDiffusionModels(fetchedModels: &fetchedModels)
|
||||||
let model = fetchedModels[i]
|
|
||||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.name)
|
|
||||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort models
|
|
||||||
sortModels(fetchedModels: &fetchedModels)
|
sortModels(fetchedModels: &fetchedModels)
|
||||||
|
self.models = fetchedModels
|
||||||
|
|
||||||
// 异步获取未下载模型的大小信息
|
// Asynchronously fetch size info for undownloaded models
|
||||||
Task {
|
Task {
|
||||||
await fetchModelSizes(for: fetchedModels)
|
await fetchModelSizes(for: fetchedModels)
|
||||||
}
|
}
|
||||||
|
@ -91,12 +95,12 @@ class ModelListViewModel: ObservableObject {
|
||||||
private func fetchModelSizes(for models: [ModelInfo]) async {
|
private func fetchModelSizes(for models: [ModelInfo]) async {
|
||||||
await withTaskGroup(of: Void.self) { group in
|
await withTaskGroup(of: Void.self) { group in
|
||||||
for (_, model) in models.enumerated() {
|
for (_, model) in models.enumerated() {
|
||||||
if !model.isDownloaded && model.cachedSize == nil {
|
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
|
||||||
group.addTask {
|
group.addTask {
|
||||||
if let size = await model.fetchRemoteSize() {
|
if let size = await model.fetchRemoteSize() {
|
||||||
await MainActor.run {
|
await MainActor.run {
|
||||||
// 查找当前模型在实际数组中的索引
|
// Find current model index in actual array
|
||||||
if let modelIndex = self.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
||||||
self.models[modelIndex].cachedSize = size
|
self.models[modelIndex].cachedSize = size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,11 +111,29 @@ class ModelListViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func recordModelUsage(modelName: String) {
|
private func filterDiffusionModels(fetchedModels: inout [ModelInfo]) {
|
||||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
let hasDiffusionModels = fetchedModels.contains {
|
||||||
if let index = models.firstIndex(where: { $0.name == modelName }) {
|
$0.modelName.lowercased().contains("diffusion")
|
||||||
models[index].lastUsedAt = Date()
|
}
|
||||||
sortModels(fetchedModels: &models)
|
|
||||||
|
if hasDiffusionModels {
|
||||||
|
fetchedModels = fetchedModels.filter { model in
|
||||||
|
let name = model.modelName.lowercased()
|
||||||
|
let tags = model.tags.map { $0.lowercased() }
|
||||||
|
|
||||||
|
// Only show GPU diffusion models
|
||||||
|
if name.contains("diffusion") {
|
||||||
|
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..<fetchedModels.count {
|
||||||
|
let model = fetchedModels[i]
|
||||||
|
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
|
||||||
|
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,24 +141,34 @@ class ModelListViewModel: ObservableObject {
|
||||||
let pinned = pinnedModelIds
|
let pinned = pinnedModelIds
|
||||||
|
|
||||||
fetchedModels.sort { (model1, model2) -> Bool in
|
fetchedModels.sort { (model1, model2) -> Bool in
|
||||||
let isPinned1 = pinned.contains(model1.modelId)
|
let isPinned1 = pinned.contains(model1.id)
|
||||||
let isPinned2 = pinned.contains(model2.modelId)
|
let isPinned2 = pinned.contains(model2.id)
|
||||||
|
let isDownloading1 = currentlyDownloading == model1.id
|
||||||
|
let isDownloading2 = currentlyDownloading == model2.id
|
||||||
|
|
||||||
|
// 1. Currently downloading models have highest priority
|
||||||
|
if isDownloading1 != isDownloading2 {
|
||||||
|
return isDownloading1
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Pinned models have second priority
|
||||||
if isPinned1 != isPinned2 {
|
if isPinned1 != isPinned2 {
|
||||||
return isPinned1
|
return isPinned1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3. If both are pinned, sort by pin time
|
||||||
if isPinned1 && isPinned2 {
|
if isPinned1 && isPinned2 {
|
||||||
let index1 = pinned.firstIndex(of: model1.modelId)!
|
let index1 = pinned.firstIndex(of: model1.id)!
|
||||||
let index2 = pinned.firstIndex(of: model2.modelId)!
|
let index2 = pinned.firstIndex(of: model2.id)!
|
||||||
return index1 > index2 // Pinned later comes first
|
return index1 > index2 // Pinned later comes first
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non-pinned models
|
// 4. Non-pinned models sorted by download status
|
||||||
if model1.isDownloaded != model2.isDownloaded {
|
if model1.isDownloaded != model2.isDownloaded {
|
||||||
return model1.isDownloaded
|
return model1.isDownloaded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5. If both downloaded, sort by last used time
|
||||||
if model1.isDownloaded {
|
if model1.isDownloaded {
|
||||||
let date1 = model1.lastUsedAt ?? .distantPast
|
let date1 = model1.lastUsedAt ?? .distantPast
|
||||||
let date2 = model2.lastUsedAt ?? .distantPast
|
let date2 = model2.lastUsedAt ?? .distantPast
|
||||||
|
@ -145,10 +177,10 @@ class ModelListViewModel: ObservableObject {
|
||||||
|
|
||||||
return false // Keep original order for not-downloaded
|
return false // Keep original order for not-downloaded
|
||||||
}
|
}
|
||||||
|
|
||||||
models = fetchedModels
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Selection & Usage
|
||||||
|
|
||||||
func selectModel(_ model: ModelInfo) {
|
func selectModel(_ model: ModelInfo) {
|
||||||
if model.isDownloaded {
|
if model.isDownloaded {
|
||||||
selectedModel = model
|
selectedModel = model
|
||||||
|
@ -159,26 +191,35 @@ class ModelListViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func recordModelUsage(modelName: String) {
|
||||||
|
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
||||||
|
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
|
||||||
|
models[index].lastUsedAt = Date()
|
||||||
|
sortModels(fetchedModels: &models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Download Management
|
||||||
|
|
||||||
func downloadModel(_ model: ModelInfo) async {
|
func downloadModel(_ model: ModelInfo) async {
|
||||||
guard currentlyDownloading == nil else { return }
|
guard currentlyDownloading == nil else { return }
|
||||||
|
|
||||||
currentlyDownloading = model.modelId
|
currentlyDownloading = model.id
|
||||||
downloadProgress[model.modelId] = 0
|
downloadProgress[model.id] = 0
|
||||||
|
|
||||||
do {
|
do {
|
||||||
try await modelClient.downloadModel(model: model) { progress in
|
try await modelClient.downloadModel(model: model) { progress in
|
||||||
Task { @MainActor in
|
Task { @MainActor in
|
||||||
self.downloadProgress[model.modelId] = progress
|
self.downloadProgress[model.id] = progress
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
|
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||||
models[index].isDownloaded = true
|
models[index].isDownloaded = true
|
||||||
ModelStorageManager.shared.markModelAsDownloaded(model.name)
|
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch {
|
} catch {
|
||||||
|
|
||||||
if case ModelScopeError.downloadCancelled = error {
|
if case ModelScopeError.downloadCancelled = error {
|
||||||
print("Download was cancelled")
|
print("Download was cancelled")
|
||||||
} else {
|
} else {
|
||||||
|
@ -188,7 +229,7 @@ class ModelListViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
currentlyDownloading = nil
|
currentlyDownloading = nil
|
||||||
downloadProgress.removeValue(forKey: model.modelId)
|
downloadProgress.removeValue(forKey: model.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cancelDownload() async {
|
func cancelDownload() async {
|
||||||
|
@ -201,24 +242,28 @@ class ModelListViewModel: ObservableObject {
|
||||||
print("Download cancelled for model: \(modelId)")
|
print("Download cancelled for model: \(modelId)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Pin Management
|
||||||
|
|
||||||
func pinModel(_ model: ModelInfo) {
|
func pinModel(_ model: ModelInfo) {
|
||||||
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return }
|
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||||
let pinned = models.remove(at: index)
|
let pinned = models.remove(at: index)
|
||||||
models.insert(pinned, at: 0)
|
models.insert(pinned, at: 0)
|
||||||
var ids = pinnedModelIds.filter { $0 != model.modelId }
|
var ids = pinnedModelIds.filter { $0 != model.id }
|
||||||
ids.append(model.modelId)
|
ids.append(model.id)
|
||||||
pinnedModelIds = ids
|
pinnedModelIds = ids
|
||||||
}
|
}
|
||||||
|
|
||||||
func unpinModel(_ model: ModelInfo) {
|
func unpinModel(_ model: ModelInfo) {
|
||||||
guard let index = models.firstIndex(where: { $0.modelId == model.modelId }) else { return }
|
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
||||||
let unpinned = models.remove(at: index)
|
let unpinned = models.remove(at: index)
|
||||||
let insertIndex = models.count // 取消置顶后放到未置顶最后
|
let insertIndex = models.count // Insert at end after unpinning
|
||||||
models.insert(unpinned, at: insertIndex)
|
models.insert(unpinned, at: insertIndex)
|
||||||
pinnedModelIds = pinnedModelIds.filter { $0 != model.modelId }
|
pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Deletion
|
||||||
|
|
||||||
func deleteModel(_ model: ModelInfo) async {
|
func deleteModel(_ model: ModelInfo) async {
|
||||||
do {
|
do {
|
||||||
let fileManager = FileManager.default
|
let fileManager = FileManager.default
|
||||||
|
@ -240,11 +285,11 @@ class ModelListViewModel: ObservableObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
await MainActor.run {
|
await MainActor.run {
|
||||||
if let index = models.firstIndex(where: { $0.modelId == model.modelId }) {
|
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
||||||
models[index].isDownloaded = false
|
models[index].isDownloaded = false
|
||||||
ModelStorageManager.shared.clearDownloadStatus(for: model.name)
|
ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
|
||||||
}
|
}
|
||||||
if selectedModel?.modelId == model.modelId {
|
if selectedModel?.id == model.id {
|
||||||
selectedModel = nil
|
selectedModel = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
//
|
||||||
|
// TBDataResponse.swift
|
||||||
|
// MNNLLMiOS
|
||||||
|
//
|
||||||
|
// Created by 游薪渝(揽清) on 2025/7/9.
|
||||||
|
//
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
struct TBDataResponse: Codable {
|
||||||
|
let tagTranslations: [String: String]
|
||||||
|
let quickFilterTags: [String]?
|
||||||
|
let models: [ModelInfo]
|
||||||
|
let metadata: Metadata?
|
||||||
|
|
||||||
|
struct Metadata: Codable {
|
||||||
|
let version: String
|
||||||
|
let lastUpdated: String
|
||||||
|
let schemaVersion: String
|
||||||
|
let totalModels: Int
|
||||||
|
let supportedPlatforms: [String]
|
||||||
|
let minAppVersion: String
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,184 +0,0 @@
|
||||||
//
|
|
||||||
// TBModelInfo.swift
|
|
||||||
// MNNLLMiOS
|
|
||||||
//
|
|
||||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
|
||||||
//
|
|
||||||
|
|
||||||
import Hub
|
|
||||||
import Foundation
|
|
||||||
|
|
||||||
struct TBModelInfo: Codable {
|
|
||||||
let modelName: String
|
|
||||||
let tags: [String]
|
|
||||||
let categories: [String]?
|
|
||||||
let size_gb: Double?
|
|
||||||
let vendor: String?
|
|
||||||
let sources: [String: String]?
|
|
||||||
let tagTranslations: [String: [String]]?
|
|
||||||
|
|
||||||
// 运行时属性
|
|
||||||
var isDownloaded: Bool = false
|
|
||||||
var lastUsedAt: Date?
|
|
||||||
var cachedSize: Int64? = nil
|
|
||||||
|
|
||||||
// 统一的ID属性,根据当前选择的源获取对应的modelId
|
|
||||||
var id: String {
|
|
||||||
guard let sources = sources else {
|
|
||||||
return "taobao-mnn/\(modelName)"
|
|
||||||
}
|
|
||||||
|
|
||||||
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
|
|
||||||
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 本地化的标签 - 使用模型自己的tagTranslations
|
|
||||||
var localizedTags: [String] {
|
|
||||||
let currentLanguage = LanguageManager.shared.currentLanguage
|
|
||||||
let isChineseLanguage = currentLanguage == "简体中文"
|
|
||||||
|
|
||||||
if isChineseLanguage, let translations = tagTranslations {
|
|
||||||
let languageCode = "zh-Hans"
|
|
||||||
return translations[languageCode] ?? tags
|
|
||||||
} else {
|
|
||||||
return tags
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var localPath: String {
|
|
||||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: id)).path
|
|
||||||
}
|
|
||||||
|
|
||||||
var formattedSize: String {
|
|
||||||
if isDownloaded {
|
|
||||||
return formatLocalSize()
|
|
||||||
} else if let cached = cachedSize {
|
|
||||||
return formatBytes(cached)
|
|
||||||
} else if let sizeGb = size_gb {
|
|
||||||
return String(format: "%.1f GB", sizeGb)
|
|
||||||
} else {
|
|
||||||
return "计算中..."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchRemoteSize() async -> Int64? {
|
|
||||||
// TODO: now only support modelScope, support huggingFace later
|
|
||||||
let modelScopeId = "taobao-mnn/\(modelName)"
|
|
||||||
|
|
||||||
do {
|
|
||||||
let files = try await fetchFileList(repoPath: id, root: "", revision: "")
|
|
||||||
let totalSize = try await calculateTotalSize(files: files, repoPath: id)
|
|
||||||
return totalSize
|
|
||||||
} catch {
|
|
||||||
print("Error fetching remote size for \(id): \(error)")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func formatLocalSize() -> String {
|
|
||||||
let path = localPath
|
|
||||||
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
|
|
||||||
|
|
||||||
do {
|
|
||||||
let totalSize = try calculateDirectorySize(at: path)
|
|
||||||
return formatBytes(totalSize)
|
|
||||||
} catch {
|
|
||||||
return "未知"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
|
||||||
let fileManager = FileManager.default
|
|
||||||
var totalSize: Int64 = 0
|
|
||||||
|
|
||||||
let enumerator = fileManager.enumerator(atPath: path)
|
|
||||||
while let fileName = enumerator?.nextObject() as? String {
|
|
||||||
let filePath = (path as NSString).appendingPathComponent(fileName)
|
|
||||||
let attributes = try fileManager.attributesOfItem(atPath: filePath)
|
|
||||||
if let fileSize = attributes[.size] as? Int64 {
|
|
||||||
totalSize += fileSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalSize
|
|
||||||
}
|
|
||||||
|
|
||||||
private func formatBytes(_ bytes: Int64) -> String {
|
|
||||||
let formatter = ByteCountFormatter()
|
|
||||||
formatter.allowedUnits = [.useGB]
|
|
||||||
formatter.countStyle = .file
|
|
||||||
return formatter.string(fromByteCount: bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MARK: - 云端文件大小计算方法
|
|
||||||
|
|
||||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
|
||||||
let url = try buildURL(
|
|
||||||
repoPath: repoPath,
|
|
||||||
path: "/repo/files",
|
|
||||||
queryItems: [
|
|
||||||
URLQueryItem(name: "Root", value: root),
|
|
||||||
URLQueryItem(name: "Revision", value: revision)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
let (data, response) = try await URLSession.shared.data(from: url)
|
|
||||||
try validateResponse(response)
|
|
||||||
|
|
||||||
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
|
|
||||||
return modelResponse.data.files
|
|
||||||
}
|
|
||||||
|
|
||||||
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
|
|
||||||
var totalSize: Int64 = 0
|
|
||||||
|
|
||||||
for file in files {
|
|
||||||
if file.type == "tree" {
|
|
||||||
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
|
|
||||||
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
|
|
||||||
} else if file.type == "blob" {
|
|
||||||
totalSize += Int64(file.size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalSize
|
|
||||||
}
|
|
||||||
|
|
||||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
|
||||||
var components = URLComponents()
|
|
||||||
components.scheme = "https"
|
|
||||||
components.host = "modelscope.cn"
|
|
||||||
components.path = "/api/v1/models/\(repoPath)\(path)"
|
|
||||||
components.queryItems = queryItems
|
|
||||||
|
|
||||||
guard let url = components.url else {
|
|
||||||
throw ModelScopeError.invalidURL
|
|
||||||
}
|
|
||||||
return url
|
|
||||||
}
|
|
||||||
|
|
||||||
private func validateResponse(_ response: URLResponse) throws {
|
|
||||||
guard let httpResponse = response as? HTTPURLResponse,
|
|
||||||
(200...299).contains(httpResponse.statusCode) else {
|
|
||||||
throw ModelScopeError.invalidResponse
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private enum CodingKeys: String, CodingKey {
|
|
||||||
case modelName
|
|
||||||
case tags
|
|
||||||
case categories
|
|
||||||
case size_gb
|
|
||||||
case vendor
|
|
||||||
case sources
|
|
||||||
case tagTranslations
|
|
||||||
case cachedSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新TagTranslationManager以支持单个标签翻译
|
|
||||||
//extension TagTranslationManager {
|
|
||||||
// func getTranslation(for tag: String) -> String? {
|
|
||||||
// return globalTagTranslations[tag]
|
|
||||||
// }
|
|
||||||
//}
|
|
|
@ -1,296 +0,0 @@
|
||||||
//
|
|
||||||
// TBModelListViewModel.swift
|
|
||||||
// MNNLLMiOS
|
|
||||||
//
|
|
||||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
|
||||||
//
|
|
||||||
|
|
||||||
import Foundation
|
|
||||||
import SwiftUI
|
|
||||||
|
|
||||||
@MainActor
|
|
||||||
class TBModelListViewModel: ObservableObject {
|
|
||||||
@Published var models: [TBModelInfo] = []
|
|
||||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
|
||||||
@Published private(set) var currentlyDownloading: String?
|
|
||||||
@Published var showError = false
|
|
||||||
@Published var errorMessage = ""
|
|
||||||
@Published var searchText = ""
|
|
||||||
@Published var quickFilterTags: [String] = []
|
|
||||||
|
|
||||||
@Published var selectedModel: TBModelInfo?
|
|
||||||
|
|
||||||
private let modelClient = TBModelClient()
|
|
||||||
private let pinnedModelKey = "com.mnnllm.pinnedModelIds"
|
|
||||||
|
|
||||||
public var pinnedModelIds: [String] {
|
|
||||||
get { UserDefaults.standard.stringArray(forKey: pinnedModelKey) ?? [] }
|
|
||||||
set { UserDefaults.standard.setValue(newValue, forKey: pinnedModelKey) }
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取所有可用的标签
|
|
||||||
var allTags: [String] {
|
|
||||||
let allTags = Set(models.flatMap { $0.tags })
|
|
||||||
return Array(allTags)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取所有可用的分类
|
|
||||||
var allCategories: [String] {
|
|
||||||
let allCategories = Set(models.compactMap { $0.categories }.flatMap { $0 })
|
|
||||||
return Array(allCategories)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取所有可用的厂商
|
|
||||||
var allVendors: [String] {
|
|
||||||
let allVendors = Set(models.compactMap { $0.vendor })
|
|
||||||
return Array(allVendors)
|
|
||||||
}
|
|
||||||
|
|
||||||
var filteredModels: [TBModelInfo] {
|
|
||||||
let filteredModels = searchText.isEmpty ? models : models.filter { model in
|
|
||||||
model.id.localizedCaseInsensitiveContains(searchText) ||
|
|
||||||
model.modelName.localizedCaseInsensitiveContains(searchText) ||
|
|
||||||
model.localizedTags.contains { $0.localizedCaseInsensitiveContains(searchText) }
|
|
||||||
}
|
|
||||||
|
|
||||||
let downloadedModels = filteredModels.filter { $0.isDownloaded }
|
|
||||||
let notDownloadedModels = filteredModels.filter { !$0.isDownloaded }
|
|
||||||
|
|
||||||
return downloadedModels + notDownloadedModels
|
|
||||||
}
|
|
||||||
|
|
||||||
init() {
|
|
||||||
Task {
|
|
||||||
await fetchModels()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchModels() async {
|
|
||||||
do {
|
|
||||||
let info = try await modelClient.getModelInfo()
|
|
||||||
|
|
||||||
self.quickFilterTags = info.quickFilterTags ?? []
|
|
||||||
|
|
||||||
TagTranslationManager.shared.loadTagTranslations(info.tagTranslations)
|
|
||||||
|
|
||||||
var fetchedModels = info.models
|
|
||||||
|
|
||||||
self.filteDiffusionModel(fetchedModels: &fetchedModels)
|
|
||||||
self.sortModels(fetchedModels: &fetchedModels)
|
|
||||||
self.models = fetchedModels
|
|
||||||
|
|
||||||
// 异步获取未下载模型的大小信息
|
|
||||||
Task {
|
|
||||||
await fetchModelSizes(for: fetchedModels)
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch {
|
|
||||||
showError = true
|
|
||||||
errorMessage = "Error: \(error.localizedDescription)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func fetchModelSizes(for models: [TBModelInfo]) async {
|
|
||||||
await withTaskGroup(of: Void.self) { group in
|
|
||||||
for (_, model) in models.enumerated() {
|
|
||||||
if !model.isDownloaded && model.cachedSize == nil && model.size_gb == nil {
|
|
||||||
group.addTask {
|
|
||||||
if let size = await model.fetchRemoteSize() {
|
|
||||||
await MainActor.run {
|
|
||||||
// 查找当前模型在实际数组中的索引
|
|
||||||
if let modelIndex = self.models.firstIndex(where: { $0.id == model.id }) {
|
|
||||||
self.models[modelIndex].cachedSize = size
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func recordModelUsage(modelName: String) {
|
|
||||||
ModelStorageManager.shared.updateLastUsed(for: modelName)
|
|
||||||
if let index = models.firstIndex(where: { $0.modelName == modelName }) {
|
|
||||||
models[index].lastUsedAt = Date()
|
|
||||||
sortModels(fetchedModels: &models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func filteDiffusionModel(fetchedModels: inout [TBModelInfo]) {
|
|
||||||
let hasDiffusionModels = fetchedModels.contains {
|
|
||||||
$0.modelName.lowercased().contains("diffusion")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasDiffusionModels {
|
|
||||||
fetchedModels = fetchedModels.filter { model in
|
|
||||||
let name = model.modelName.lowercased()
|
|
||||||
let tags = model.tags.map { $0.lowercased() }
|
|
||||||
|
|
||||||
// only show gpu diffusion
|
|
||||||
if name.contains("diffusion") {
|
|
||||||
return name.contains("gpu") || tags.contains { $0.contains("gpu") }
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in 0..<fetchedModels.count {
|
|
||||||
let model = fetchedModels[i]
|
|
||||||
fetchedModels[i].isDownloaded = ModelStorageManager.shared.isModelDownloaded(model.modelName)
|
|
||||||
fetchedModels[i].lastUsedAt = ModelStorageManager.shared.getLastUsed(for: model.modelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func sortModels(fetchedModels: inout [TBModelInfo]) {
|
|
||||||
let pinned = pinnedModelIds
|
|
||||||
|
|
||||||
fetchedModels.sort { (model1, model2) -> Bool in
|
|
||||||
let isPinned1 = pinned.contains(model1.id)
|
|
||||||
let isPinned2 = pinned.contains(model2.id)
|
|
||||||
let isDownloading1 = currentlyDownloading == model1.id
|
|
||||||
let isDownloading2 = currentlyDownloading == model2.id
|
|
||||||
|
|
||||||
// 1. 正在下载的模型优先级最高
|
|
||||||
if isDownloading1 != isDownloading2 {
|
|
||||||
return isDownloading1
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 置顶的模型次优先级
|
|
||||||
if isPinned1 != isPinned2 {
|
|
||||||
return isPinned1
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 如果都是置顶的,按置顶时间排序
|
|
||||||
if isPinned1 && isPinned2 {
|
|
||||||
let index1 = pinned.firstIndex(of: model1.id)!
|
|
||||||
let index2 = pinned.firstIndex(of: model2.id)!
|
|
||||||
return index1 > index2 // Pinned later comes first
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 非置顶模型按下载状态排序
|
|
||||||
if model1.isDownloaded != model2.isDownloaded {
|
|
||||||
return model1.isDownloaded
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 如果都已下载,按最后使用时间排序
|
|
||||||
if model1.isDownloaded {
|
|
||||||
let date1 = model1.lastUsedAt ?? .distantPast
|
|
||||||
let date2 = model2.lastUsedAt ?? .distantPast
|
|
||||||
return date1 > date2
|
|
||||||
}
|
|
||||||
|
|
||||||
return false // Keep original order for not-downloaded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectModel(_ model: TBModelInfo) {
|
|
||||||
if model.isDownloaded {
|
|
||||||
selectedModel = model
|
|
||||||
} else {
|
|
||||||
Task {
|
|
||||||
await downloadModel(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func downloadModel(_ model: TBModelInfo) async {
|
|
||||||
guard currentlyDownloading == nil else { return }
|
|
||||||
|
|
||||||
currentlyDownloading = model.id
|
|
||||||
downloadProgress[model.id] = 0
|
|
||||||
|
|
||||||
do {
|
|
||||||
try await modelClient.downloadModel(model: model) { progress in
|
|
||||||
Task { @MainActor in
|
|
||||||
self.downloadProgress[model.id] = progress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
|
||||||
models[index].isDownloaded = true
|
|
||||||
ModelStorageManager.shared.markModelAsDownloaded(model.modelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch {
|
|
||||||
|
|
||||||
if case ModelScopeError.downloadCancelled = error {
|
|
||||||
print("Download was cancelled")
|
|
||||||
} else {
|
|
||||||
showError = true
|
|
||||||
errorMessage = "Failed to download model: \(error.localizedDescription)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
currentlyDownloading = nil
|
|
||||||
downloadProgress.removeValue(forKey: model.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cancelDownload() async {
|
|
||||||
if let modelId = currentlyDownloading {
|
|
||||||
await modelClient.cancelDownload()
|
|
||||||
|
|
||||||
downloadProgress.removeValue(forKey: modelId)
|
|
||||||
currentlyDownloading = nil
|
|
||||||
|
|
||||||
print("Download cancelled for model: \(modelId)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func pinModel(_ model: TBModelInfo) {
|
|
||||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
|
||||||
let pinned = models.remove(at: index)
|
|
||||||
models.insert(pinned, at: 0)
|
|
||||||
var ids = pinnedModelIds.filter { $0 != model.id }
|
|
||||||
ids.append(model.id)
|
|
||||||
pinnedModelIds = ids
|
|
||||||
}
|
|
||||||
|
|
||||||
func unpinModel(_ model: TBModelInfo) {
|
|
||||||
guard let index = models.firstIndex(where: { $0.id == model.id }) else { return }
|
|
||||||
let unpinned = models.remove(at: index)
|
|
||||||
let insertIndex = models.count // 取消置顶后放到未置顶最后
|
|
||||||
models.insert(unpinned, at: insertIndex)
|
|
||||||
pinnedModelIds = pinnedModelIds.filter { $0 != model.id }
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteModel(_ model: TBModelInfo) async {
|
|
||||||
do {
|
|
||||||
let fileManager = FileManager.default
|
|
||||||
let modelPath = URL.init(filePath: model.localPath)
|
|
||||||
|
|
||||||
if let files = try? fileManager.contentsOfDirectory(
|
|
||||||
at: modelPath,
|
|
||||||
includingPropertiesForKeys: nil,
|
|
||||||
options: [.skipsHiddenFiles]
|
|
||||||
) {
|
|
||||||
let storage = ModelDownloadStorage()
|
|
||||||
for file in files {
|
|
||||||
storage.clearFileStatus(at: file.path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if fileManager.fileExists(atPath: modelPath.path) {
|
|
||||||
try fileManager.removeItem(at: modelPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
await MainActor.run {
|
|
||||||
if let index = models.firstIndex(where: { $0.id == model.id }) {
|
|
||||||
models[index].isDownloaded = false
|
|
||||||
ModelStorageManager.shared.clearDownloadStatus(for: model.modelName)
|
|
||||||
}
|
|
||||||
if selectedModel?.id == model.id {
|
|
||||||
selectedModel = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch {
|
|
||||||
print("Error deleting model: \(error)")
|
|
||||||
await MainActor.run {
|
|
||||||
self.errorMessage = "Failed to delete model: \(error.localizedDescription)"
|
|
||||||
self.showError = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,7 +2,7 @@
|
||||||
// ModelClient.swift
|
// ModelClient.swift
|
||||||
// MNNLLMiOS
|
// MNNLLMiOS
|
||||||
//
|
//
|
||||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||||
//
|
//
|
||||||
|
|
||||||
import Hub
|
import Hub
|
||||||
|
@ -26,17 +26,41 @@ class ModelClient {
|
||||||
|
|
||||||
init() {}
|
init() {}
|
||||||
|
|
||||||
func getModelList() async throws -> [ModelInfo] {
|
func getModelInfo() async throws -> TBDataResponse {
|
||||||
let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||||
return try await performRequest(url: url, retries: maxRetries)
|
throw NetworkError.invalidData
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = try Data(contentsOf: url)
|
||||||
|
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||||
|
return mockResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRepoInfo(repoName: String, revision: String) async throws -> RepoInfo {
|
func getModelList() async throws -> [ModelInfo] {
|
||||||
let url = URL(string: "\(baseURLString)/api/models/\(repoName)")!
|
// TODO: get json from network
|
||||||
return try await performRequest(url: url, retries: maxRetries)
|
// let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
||||||
|
// return try await performRequest(url: url, retries: maxRetries)
|
||||||
|
|
||||||
|
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
||||||
|
throw NetworkError.invalidData
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = try Data(contentsOf: url)
|
||||||
|
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
||||||
|
|
||||||
|
// 加载全局标签翻译
|
||||||
|
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
|
||||||
|
|
||||||
|
return mockResponse.models
|
||||||
}
|
}
|
||||||
|
|
||||||
@MainActor
|
/**
|
||||||
|
* Downloads a model from the selected source with progress tracking
|
||||||
|
*
|
||||||
|
* @param model The ModelInfo object containing model details
|
||||||
|
* @param progress Progress callback that receives download progress (0.0 to 1.0)
|
||||||
|
* @throws Various network or file system errors
|
||||||
|
*/
|
||||||
func downloadModel(model: ModelInfo,
|
func downloadModel(model: ModelInfo,
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
progress: @escaping (Double) -> Void) async throws {
|
||||||
switch ModelSourceManager.shared.selectedSource {
|
switch ModelSourceManager.shared.selectedSource {
|
||||||
|
@ -47,7 +71,9 @@ class ModelClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@MainActor
|
/**
|
||||||
|
* Cancels the current download operation
|
||||||
|
*/
|
||||||
func cancelDownload() async {
|
func cancelDownload() async {
|
||||||
if let manager = currentDownloadManager {
|
if let manager = currentDownloadManager {
|
||||||
await manager.cancelDownload()
|
await manager.cancelDownload()
|
||||||
|
@ -55,10 +81,16 @@ class ModelClient {
|
||||||
print("Download cancelled")
|
print("Download cancelled")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* Downloads model from ModelScope platform
|
||||||
|
*
|
||||||
|
* @param model The ModelInfo object to download
|
||||||
|
* @param progress Progress callback for download updates
|
||||||
|
* @throws Download or network related errors
|
||||||
|
*/
|
||||||
private func downloadFromModelScope(_ model: ModelInfo,
|
private func downloadFromModelScope(_ model: ModelInfo,
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
progress: @escaping (Double) -> Void) async throws {
|
||||||
let ModelScopeId = model.modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
|
let ModelScopeId = model.id
|
||||||
let config = URLSessionConfiguration.default
|
let config = URLSessionConfiguration.default
|
||||||
config.timeoutIntervalForRequest = 30
|
config.timeoutIntervalForRequest = 30
|
||||||
config.timeoutIntervalForResource = 300
|
config.timeoutIntervalForResource = 300
|
||||||
|
@ -66,56 +98,68 @@ class ModelClient {
|
||||||
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
||||||
currentDownloadManager = manager
|
currentDownloadManager = manager
|
||||||
|
|
||||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.name) { fileProgress in
|
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
|
||||||
progress(fileProgress)
|
Task { @MainActor in
|
||||||
|
progress(fileProgress)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
currentDownloadManager = nil
|
currentDownloadManager = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Downloads model from HuggingFace platform with optimized progress updates
|
||||||
|
*
|
||||||
|
* This method implements throttling to prevent UI stuttering by limiting
|
||||||
|
* progress update frequency and filtering out minor progress changes.
|
||||||
|
*
|
||||||
|
* @param model The ModelInfo object to download
|
||||||
|
* @param progress Progress callback for download updates
|
||||||
|
* @throws Download or network related errors
|
||||||
|
*/
|
||||||
private func downloadFromHuggingFace(_ model: ModelInfo,
|
private func downloadFromHuggingFace(_ model: ModelInfo,
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
progress: @escaping (Double) -> Void) async throws {
|
||||||
let repo = Hub.Repo(id: model.modelId)
|
let repo = Hub.Repo(id: model.id)
|
||||||
let modelFiles = ["*.*"]
|
let modelFiles = ["*.*"]
|
||||||
let mirrorHubApi = HubApi(endpoint: baseURL)
|
let mirrorHubApi = HubApi(endpoint: baseURL)
|
||||||
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
|
||||||
progress(fileProgress.fractionCompleted)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func performRequest<T: Decodable>(url: URL, retries: Int = 3) async throws -> T {
|
|
||||||
var lastError: Error?
|
|
||||||
|
|
||||||
for attempt in 1...retries {
|
// Progress throttling mechanism to prevent UI stuttering
|
||||||
do {
|
var lastUpdateTime = Date()
|
||||||
var request = URLRequest(url: url)
|
var lastProgress: Double = 0.0
|
||||||
request.setValue("application/json", forHTTPHeaderField: "Accept")
|
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
|
||||||
|
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
|
||||||
|
|
||||||
|
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
||||||
|
let currentProgress = fileProgress.fractionCompleted
|
||||||
|
let currentTime = Date()
|
||||||
|
|
||||||
|
// Check if progress should be updated
|
||||||
|
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
||||||
|
let progressDiff = abs(currentProgress - lastProgress)
|
||||||
|
|
||||||
|
// Update progress if any of these conditions are met:
|
||||||
|
// 1. Time interval exceeds threshold
|
||||||
|
// 2. Progress change exceeds threshold
|
||||||
|
// 3. Progress reaches 100% (download complete)
|
||||||
|
// 4. Progress is 0% (download start)
|
||||||
|
if timeDiff >= progressUpdateInterval ||
|
||||||
|
progressDiff >= progressThreshold ||
|
||||||
|
currentProgress >= 1.0 ||
|
||||||
|
currentProgress == 0.0 {
|
||||||
|
|
||||||
let (data, response) = try await URLSession.shared.data(for: request)
|
lastUpdateTime = currentTime
|
||||||
|
lastProgress = currentProgress
|
||||||
|
|
||||||
guard let httpResponse = response as? HTTPURLResponse else {
|
// Ensure progress updates are executed on the main thread
|
||||||
throw NetworkError.invalidResponse
|
Task { @MainActor in
|
||||||
}
|
progress(currentProgress)
|
||||||
|
|
||||||
if httpResponse.statusCode == 200 {
|
|
||||||
return try JSONDecoder().decode(T.self, from: data)
|
|
||||||
}
|
|
||||||
|
|
||||||
throw NetworkError.invalidResponse
|
|
||||||
|
|
||||||
} catch {
|
|
||||||
lastError = error
|
|
||||||
if attempt < retries {
|
|
||||||
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(attempt)) * 1_000_000_000))
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
throw lastError ?? NetworkError.unknown
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
enum NetworkError: Error {
|
enum NetworkError: Error {
|
||||||
case invalidResponse
|
case invalidResponse
|
||||||
case invalidData
|
case invalidData
|
||||||
|
|
|
@ -1,176 +0,0 @@
|
||||||
//
|
|
||||||
// TBModelClient.swift
|
|
||||||
// MNNLLMiOS
|
|
||||||
//
|
|
||||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
|
||||||
//
|
|
||||||
|
|
||||||
import Hub
|
|
||||||
import Foundation
|
|
||||||
|
|
||||||
class TBModelClient {
|
|
||||||
private let baseMirrorURL = "https://hf-mirror.com"
|
|
||||||
private let baseURL = "https://huggingface.co"
|
|
||||||
private let maxRetries = 5
|
|
||||||
|
|
||||||
private var currentDownloadManager: ModelScopeDownloadManager?
|
|
||||||
|
|
||||||
private lazy var baseURLString: String = {
|
|
||||||
switch ModelSourceManager.shared.selectedSource {
|
|
||||||
case .huggingFace:
|
|
||||||
return baseURL
|
|
||||||
default:
|
|
||||||
return baseMirrorURL
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
init() {}
|
|
||||||
|
|
||||||
func getModelInfo() async throws -> TBDataResponse {
|
|
||||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
|
||||||
throw NetworkError.invalidData
|
|
||||||
}
|
|
||||||
|
|
||||||
let data = try Data(contentsOf: url)
|
|
||||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
|
||||||
return mockResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func getModelList() async throws -> [TBModelInfo] {
|
|
||||||
// TODO: get json from network
|
|
||||||
// let url = URL(string: "\(baseURLString)/api/models?author=taobao-mnn&limit=100")!
|
|
||||||
// return try await performRequest(url: url, retries: maxRetries)
|
|
||||||
|
|
||||||
guard let url = Bundle.main.url(forResource: "mock", withExtension: "json") else {
|
|
||||||
throw NetworkError.invalidData
|
|
||||||
}
|
|
||||||
|
|
||||||
let data = try Data(contentsOf: url)
|
|
||||||
let mockResponse = try JSONDecoder().decode(TBDataResponse.self, from: data)
|
|
||||||
|
|
||||||
// 加载全局标签翻译
|
|
||||||
TagTranslationManager.shared.loadTagTranslations(mockResponse.tagTranslations)
|
|
||||||
|
|
||||||
return mockResponse.models
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Downloads a model from the selected source with progress tracking
|
|
||||||
*
|
|
||||||
* @param model The ModelInfo object containing model details
|
|
||||||
* @param progress Progress callback that receives download progress (0.0 to 1.0)
|
|
||||||
* @throws Various network or file system errors
|
|
||||||
*/
|
|
||||||
func downloadModel(model: TBModelInfo,
|
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
|
||||||
switch ModelSourceManager.shared.selectedSource {
|
|
||||||
case .modelScope, .modeler:
|
|
||||||
try await downloadFromModelScope(model, progress: progress)
|
|
||||||
case .huggingFace:
|
|
||||||
try await downloadFromHuggingFace(model, progress: progress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cancels the current download operation
|
|
||||||
*/
|
|
||||||
func cancelDownload() async {
|
|
||||||
if let manager = currentDownloadManager {
|
|
||||||
await manager.cancelDownload()
|
|
||||||
currentDownloadManager = nil
|
|
||||||
print("Download cancelled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Downloads model from ModelScope platform
|
|
||||||
*
|
|
||||||
* @param model The ModelInfo object to download
|
|
||||||
* @param progress Progress callback for download updates
|
|
||||||
* @throws Download or network related errors
|
|
||||||
*/
|
|
||||||
private func downloadFromModelScope(_ model: TBModelInfo,
|
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
|
||||||
let ModelScopeId = model.id
|
|
||||||
let config = URLSessionConfiguration.default
|
|
||||||
config.timeoutIntervalForRequest = 30
|
|
||||||
config.timeoutIntervalForResource = 300
|
|
||||||
|
|
||||||
let manager = ModelScopeDownloadManager.init(repoPath: ModelScopeId, config: config, enableLogging: true, source: ModelSourceManager.shared.selectedSource)
|
|
||||||
currentDownloadManager = manager
|
|
||||||
|
|
||||||
try await manager.downloadModel(to:"huggingface/models/taobao-mnn", modelId: ModelScopeId, modelName: model.modelName) { fileProgress in
|
|
||||||
Task { @MainActor in
|
|
||||||
progress(fileProgress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
currentDownloadManager = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Downloads model from HuggingFace platform with optimized progress updates
|
|
||||||
*
|
|
||||||
* This method implements throttling to prevent UI stuttering by limiting
|
|
||||||
* progress update frequency and filtering out minor progress changes.
|
|
||||||
*
|
|
||||||
* @param model The ModelInfo object to download
|
|
||||||
* @param progress Progress callback for download updates
|
|
||||||
* @throws Download or network related errors
|
|
||||||
*/
|
|
||||||
private func downloadFromHuggingFace(_ model: TBModelInfo,
|
|
||||||
progress: @escaping (Double) -> Void) async throws {
|
|
||||||
let repo = Hub.Repo(id: model.id)
|
|
||||||
let modelFiles = ["*.*"]
|
|
||||||
let mirrorHubApi = HubApi(endpoint: baseURL)
|
|
||||||
|
|
||||||
// Progress throttling mechanism to prevent UI stuttering
|
|
||||||
var lastUpdateTime = Date()
|
|
||||||
var lastProgress: Double = 0.0
|
|
||||||
let progressUpdateInterval: TimeInterval = 0.1 // Limit update frequency to every 100ms
|
|
||||||
let progressThreshold: Double = 0.01 // Progress change threshold of 1%
|
|
||||||
|
|
||||||
try await mirrorHubApi.snapshot(from: repo, matching: modelFiles) { fileProgress in
|
|
||||||
let currentProgress = fileProgress.fractionCompleted
|
|
||||||
let currentTime = Date()
|
|
||||||
|
|
||||||
// Check if progress should be updated
|
|
||||||
let timeDiff = currentTime.timeIntervalSince(lastUpdateTime)
|
|
||||||
let progressDiff = abs(currentProgress - lastProgress)
|
|
||||||
|
|
||||||
// Update progress if any of these conditions are met:
|
|
||||||
// 1. Time interval exceeds threshold
|
|
||||||
// 2. Progress change exceeds threshold
|
|
||||||
// 3. Progress reaches 100% (download complete)
|
|
||||||
// 4. Progress is 0% (download start)
|
|
||||||
if timeDiff >= progressUpdateInterval ||
|
|
||||||
progressDiff >= progressThreshold ||
|
|
||||||
currentProgress >= 1.0 ||
|
|
||||||
currentProgress == 0.0 {
|
|
||||||
|
|
||||||
lastUpdateTime = currentTime
|
|
||||||
lastProgress = currentProgress
|
|
||||||
|
|
||||||
// Ensure progress updates are executed on the main thread
|
|
||||||
Task { @MainActor in
|
|
||||||
progress(currentProgress)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TBDataResponse: Codable {
|
|
||||||
let tagTranslations: [String: String]
|
|
||||||
let quickFilterTags: [String]?
|
|
||||||
let models: [TBModelInfo]
|
|
||||||
let metadata: MockMetadata?
|
|
||||||
|
|
||||||
struct MockMetadata: Codable {
|
|
||||||
let version: String
|
|
||||||
let lastUpdated: String
|
|
||||||
let schemaVersion: String
|
|
||||||
let totalModels: Int
|
|
||||||
let supportedPlatforms: [String]
|
|
||||||
let minAppVersion: String
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -10,7 +10,7 @@ import SwiftUI
|
||||||
// MARK: - 筛选菜单视图
|
// MARK: - 筛选菜单视图
|
||||||
struct FilterMenuView: View {
|
struct FilterMenuView: View {
|
||||||
@Environment(\.dismiss) private var dismiss
|
@Environment(\.dismiss) private var dismiss
|
||||||
@StateObject private var viewModel = TBModelListViewModel()
|
@StateObject private var viewModel = ModelListViewModel()
|
||||||
@Binding var selectedTags: Set<String>
|
@Binding var selectedTags: Set<String>
|
||||||
@Binding var selectedCategories: Set<String>
|
@Binding var selectedCategories: Set<String>
|
||||||
@Binding var selectedVendors: Set<String>
|
@Binding var selectedVendors: Set<String>
|
||||||
|
|
|
@ -9,7 +9,7 @@ import SwiftUI
|
||||||
|
|
||||||
// MARK: - 工具栏视图
|
// MARK: - 工具栏视图
|
||||||
struct ToolbarView: View {
|
struct ToolbarView: View {
|
||||||
@ObservedObject var viewModel: TBModelListViewModel
|
@ObservedObject var viewModel: ModelListViewModel
|
||||||
@Binding var selectedSource: ModelSource
|
@Binding var selectedSource: ModelSource
|
||||||
@Binding var showSourceMenu: Bool
|
@Binding var showSourceMenu: Bool
|
||||||
@Binding var selectedTags: Set<String>
|
@Binding var selectedTags: Set<String>
|
||||||
|
|
|
@ -2,136 +2,146 @@
|
||||||
// ModelListView.swift
|
// ModelListView.swift
|
||||||
// MNNLLMiOS
|
// MNNLLMiOS
|
||||||
//
|
//
|
||||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||||
//
|
//
|
||||||
|
|
||||||
import SwiftUI
|
import SwiftUI
|
||||||
|
|
||||||
|
|
||||||
struct ModelListView: View {
|
struct ModelListView: View {
|
||||||
@ObservedObject var viewModel: ModelListViewModel
|
@ObservedObject var viewModel: ModelListViewModel
|
||||||
|
@State private var searchText = ""
|
||||||
@State private var scrollOffset: CGFloat = 0
|
|
||||||
@State private var showHelp = false
|
|
||||||
@State private var showUserGuide = false
|
|
||||||
|
|
||||||
@State private var downloadSources: ModelSource?
|
|
||||||
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
||||||
|
@State private var showSourceMenu = false
|
||||||
@State private var showOptions = false
|
@State private var selectedTags: Set<String> = []
|
||||||
@State private var buttonFrame: CGRect = .zero
|
@State private var selectedCategories: Set<String> = []
|
||||||
|
@State private var selectedVendors: Set<String> = []
|
||||||
|
@State private var showFilterMenu = false
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
ZStack {
|
ScrollView {
|
||||||
VStack {
|
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||||
HStack {
|
Section {
|
||||||
Button {
|
modelListSection
|
||||||
showOptions.toggle()
|
} header: {
|
||||||
} label: {
|
toolbarSection
|
||||||
HStack {
|
|
||||||
Text("下载源:")
|
|
||||||
.font(.system(size: 12, weight: .regular))
|
|
||||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
|
||||||
Text(selectedSource.rawValue)
|
|
||||||
.font(.system(size: 12, weight: .regular))
|
|
||||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
|
||||||
Image(systemName: "chevron.down")
|
|
||||||
.frame(width: 10, height: 10, alignment: .leading)
|
|
||||||
.scaledToFit()
|
|
||||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
|
||||||
}
|
|
||||||
.padding(.leading)
|
|
||||||
}
|
|
||||||
Spacer()
|
|
||||||
}
|
}
|
||||||
.frame(maxWidth: .infinity, maxHeight: 20)
|
|
||||||
.background(
|
|
||||||
GeometryReader { geometry in
|
|
||||||
Color.white.onAppear {
|
|
||||||
buttonFrame = geometry.frame(in: .global)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
List {
|
|
||||||
SearchBar(text: $viewModel.searchText)
|
|
||||||
.listRowInsets(EdgeInsets())
|
|
||||||
.listRowSeparator(.hidden)
|
|
||||||
.padding(.horizontal)
|
|
||||||
|
|
||||||
ForEach(viewModel.filteredModels, id: \.modelId) { model in
|
|
||||||
|
|
||||||
ModelRowView(model: model,
|
|
||||||
viewModel: viewModel,
|
|
||||||
downloadProgress: viewModel.downloadProgress[model.modelId] ?? 0,
|
|
||||||
isDownloading: viewModel.currentlyDownloading == model.modelId,
|
|
||||||
isOtherDownloading: viewModel.currentlyDownloading != nil) {
|
|
||||||
if model.isDownloaded {
|
|
||||||
viewModel.selectModel(model)
|
|
||||||
} else {
|
|
||||||
Task {
|
|
||||||
await viewModel.downloadModel(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.listRowSeparator(.hidden)
|
|
||||||
.listRowBackground(viewModel.pinnedModelIds.contains(model.modelId) ? Color.black.opacity(0.05) : Color.clear)
|
|
||||||
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
|
|
||||||
SwipeActionsView(model: model, viewModel: viewModel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.listStyle(.plain)
|
|
||||||
.sheet(isPresented: $showHelp) {
|
|
||||||
HelpView()
|
|
||||||
}
|
|
||||||
.refreshable {
|
|
||||||
await viewModel.fetchModels()
|
|
||||||
}
|
|
||||||
.alert("Error", isPresented: $viewModel.showError) {
|
|
||||||
Button("OK", role: .cancel) {}
|
|
||||||
} message: {
|
|
||||||
Text(viewModel.errorMessage)
|
|
||||||
}
|
|
||||||
.onAppear {
|
|
||||||
checkFirstLaunch()
|
|
||||||
}
|
|
||||||
.alert(isPresented: $showUserGuide) {
|
|
||||||
Alert(
|
|
||||||
title: Text("User Guide"),
|
|
||||||
message: Text("""
|
|
||||||
This is a local large model application that requires certain performance from your device.
|
|
||||||
It is recommended to choose different model sizes based on your device's memory.
|
|
||||||
|
|
||||||
The model recommendations for iPhone are as follows:
|
|
||||||
- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).
|
|
||||||
- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).
|
|
||||||
- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).
|
|
||||||
|
|
||||||
Choosing a model that is too large may cause insufficient memory and crashes.
|
|
||||||
"""),
|
|
||||||
dismissButton: .default(Text("OK"))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
Spacer()
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
.searchable(text: $searchText, prompt: "Search models...")
|
||||||
|
.onChange(of: searchText) { _, newValue in
|
||||||
|
viewModel.searchText = newValue
|
||||||
|
}
|
||||||
|
.refreshable {
|
||||||
|
await viewModel.fetchModels()
|
||||||
|
}
|
||||||
|
.alert("Error", isPresented: $viewModel.showError) {
|
||||||
|
Button("OK") { }
|
||||||
|
} message: {
|
||||||
|
Text(viewModel.errorMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract model list section as independent view
|
||||||
|
@ViewBuilder
|
||||||
|
private var modelListSection: some View {
|
||||||
|
LazyVStack(spacing: 8) {
|
||||||
|
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
|
||||||
|
modelRowView(model: model, index: index)
|
||||||
|
|
||||||
|
if index < filteredModels.count - 1 {
|
||||||
|
Divider()
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.padding(.vertical, 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract toolbar section as independent view
|
||||||
|
@ViewBuilder
|
||||||
|
private var toolbarSection: some View {
|
||||||
|
ToolbarView(
|
||||||
|
viewModel: viewModel, selectedSource: $selectedSource,
|
||||||
|
showSourceMenu: $showSourceMenu,
|
||||||
|
selectedTags: $selectedTags,
|
||||||
|
selectedCategories: $selectedCategories,
|
||||||
|
selectedVendors: $selectedVendors,
|
||||||
|
quickFilterTags: viewModel.quickFilterTags,
|
||||||
|
showFilterMenu: $showFilterMenu,
|
||||||
|
onSourceChange: handleSourceChange
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract single model row view as independent method
|
||||||
|
@ViewBuilder
|
||||||
|
private func modelRowView(model: ModelInfo, index: Int) -> some View {
|
||||||
|
ModelRowView(
|
||||||
|
model: model,
|
||||||
|
viewModel: viewModel,
|
||||||
|
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
|
||||||
|
isDownloading: viewModel.currentlyDownloading == model.id,
|
||||||
|
isOtherDownloading: isOtherDownloadingCheck(model: model)
|
||||||
|
) {
|
||||||
|
Task {
|
||||||
|
await viewModel.downloadModel(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract complex boolean logic as independent method
|
||||||
|
private func isOtherDownloadingCheck(model: ModelInfo) -> Bool {
|
||||||
|
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract source change handling logic as independent method
|
||||||
|
private func handleSourceChange(_ source: ModelSource) {
|
||||||
|
ModelSourceManager.shared.updateSelectedSource(source)
|
||||||
|
selectedSource = source
|
||||||
|
Task {
|
||||||
|
await viewModel.fetchModels()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter models based on selected tags, categories and vendors
|
||||||
|
private var filteredModels: [ModelInfo] {
|
||||||
|
let baseFiltered = viewModel.filteredModels
|
||||||
|
|
||||||
|
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
|
||||||
|
return baseFiltered
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseFiltered.filter { model in
|
||||||
|
let tagMatch = checkTagMatch(model: model)
|
||||||
|
let categoryMatch = checkCategoryMatch(model: model)
|
||||||
|
let vendorMatch = checkVendorMatch(model: model)
|
||||||
|
|
||||||
if showOptions {
|
return tagMatch && categoryMatch && vendorMatch
|
||||||
CustomPopupMenu(isPresented: $showOptions,
|
}
|
||||||
selectedSource: $selectedSource,
|
}
|
||||||
anchorFrame: buttonFrame)
|
|
||||||
|
// Extract tag matching logic as independent method
|
||||||
|
private func checkTagMatch(model: ModelInfo) -> Bool {
|
||||||
|
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
|
||||||
|
model.localizedTags.contains { tag in
|
||||||
|
tag.localizedCaseInsensitiveContains(selectedTag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func checkFirstLaunch() {
|
// Extract category matching logic as independent method
|
||||||
let hasLaunchedBefore = UserDefaults.standard.bool(forKey: "hasLaunchedBefore")
|
private func checkCategoryMatch(model: ModelInfo) -> Bool {
|
||||||
if !hasLaunchedBefore {
|
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
|
||||||
// Show the user guide alert
|
model.categories?.contains { category in
|
||||||
showUserGuide = true
|
category.localizedCaseInsensitiveContains(selectedCategory)
|
||||||
// Set the flag to true so it doesn't show again
|
} ?? false
|
||||||
UserDefaults.standard.set(true, forKey: "hasLaunchedBefore")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract vendor matching logic as independent method
|
||||||
|
private func checkVendorMatch(model: ModelInfo) -> Bool {
|
||||||
|
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
|
||||||
|
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,8 @@ import SwiftUI
|
||||||
|
|
||||||
// MARK: - 操作按钮视图
|
// MARK: - 操作按钮视图
|
||||||
struct ActionButtonsView: View {
|
struct ActionButtonsView: View {
|
||||||
let model: TBModelInfo
|
let model: ModelInfo
|
||||||
@ObservedObject var viewModel: TBModelListViewModel
|
@ObservedObject var viewModel: ModelListViewModel
|
||||||
let downloadProgress: Double
|
let downloadProgress: Double
|
||||||
let isDownloading: Bool
|
let isDownloading: Bool
|
||||||
let isOtherDownloading: Bool
|
let isOtherDownloading: Bool
|
||||||
|
@ -40,4 +40,4 @@ struct ActionButtonsView: View {
|
||||||
}
|
}
|
||||||
.frame(width: 60)
|
.frame(width: 60)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import SwiftUI
|
||||||
|
|
||||||
// MARK: - 下载中按钮视图
|
// MARK: - 下载中按钮视图
|
||||||
struct DownloadingButtonView: View {
|
struct DownloadingButtonView: View {
|
||||||
@ObservedObject var viewModel: TBModelListViewModel
|
@ObservedObject var viewModel: ModelListViewModel
|
||||||
let downloadProgress: Double
|
let downloadProgress: Double
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
|
@ -29,4 +29,4 @@ struct DownloadingButtonView: View {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
// ModelRowView.swift
|
// ModelRowView.swift
|
||||||
// MNNLLMiOS
|
// MNNLLMiOS
|
||||||
//
|
//
|
||||||
// Created by 游薪渝(揽清) on 2025/1/3.
|
// Created by 游薪渝(揽清) on 2025/7/4.
|
||||||
//
|
//
|
||||||
|
|
||||||
import SwiftUI
|
import SwiftUI
|
||||||
|
@ -17,127 +17,78 @@ struct ModelRowView: View {
|
||||||
let isOtherDownloading: Bool
|
let isOtherDownloading: Bool
|
||||||
let onDownload: () -> Void
|
let onDownload: () -> Void
|
||||||
|
|
||||||
|
|
||||||
@State private var showDeleteAlert = false
|
@State private var showDeleteAlert = false
|
||||||
|
|
||||||
|
// 预计算本地化标签,避免重复计算
|
||||||
|
private var localizedTags: [String] {
|
||||||
|
model.localizedTags
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预计算格式化大小,避免重复计算
|
||||||
|
private var formattedSize: String {
|
||||||
|
model.formattedSize
|
||||||
|
}
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
HStack(alignment: .top) {
|
HStack(alignment: .top, spacing: 0) {
|
||||||
ModelIconView(modelId: model.modelId)
|
// 模型图标
|
||||||
|
ModelIconView(modelId: model.id)
|
||||||
.frame(width: 40, height: 40)
|
.frame(width: 40, height: 40)
|
||||||
|
|
||||||
VStack(alignment: .leading, spacing: 5) {
|
// 主要信息区域
|
||||||
Text(model.name)
|
VStack(alignment: .leading, spacing: 6) {
|
||||||
|
// 模型名称
|
||||||
|
Text(model.modelName)
|
||||||
.font(.headline)
|
.font(.headline)
|
||||||
.fontWeight(.semibold)
|
.fontWeight(.semibold)
|
||||||
.lineLimit(1)
|
.lineLimit(1)
|
||||||
|
|
||||||
if let lastUsedAt = model.lastUsedAt {
|
// 标签列表
|
||||||
Text("Last used: \(lastUsedAt.formatAgo())")
|
if !localizedTags.isEmpty {
|
||||||
.font(.caption2)
|
TagsView(tags: localizedTags)
|
||||||
.foregroundColor(.gray)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !model.tags.isEmpty {
|
|
||||||
ScrollView(.horizontal, showsIndicators: false) {
|
|
||||||
HStack {
|
|
||||||
ForEach(model.tags, id: \.self) { tag in
|
|
||||||
Text(tag)
|
|
||||||
.fontWeight(.regular)
|
|
||||||
.font(.caption)
|
|
||||||
.foregroundColor(.gray)
|
|
||||||
.padding(.horizontal, 8)
|
|
||||||
.padding(.vertical, 3)
|
|
||||||
.background(
|
|
||||||
RoundedRectangle(cornerRadius: 8)
|
|
||||||
.stroke(Color.gray.opacity(0.5), lineWidth: 0.5)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.frame(height: 25)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
.padding(.leading, 8)
|
||||||
|
|
||||||
Spacer()
|
Spacer()
|
||||||
|
|
||||||
VStack(alignment: .center, spacing: 4) {
|
ActionButtonsView(
|
||||||
if model.isDownloaded {
|
model: model,
|
||||||
Button(action: {
|
viewModel: viewModel,
|
||||||
showDeleteAlert = true
|
downloadProgress: downloadProgress,
|
||||||
}) {
|
isDownloading: isDownloading,
|
||||||
Image(systemName: "trash")
|
isOtherDownloading: isOtherDownloading,
|
||||||
.fontWeight(.regular)
|
formattedSize: formattedSize,
|
||||||
.foregroundColor(.black.opacity(0.8))
|
onDownload: onDownload,
|
||||||
.frame(width: 20, height: 20)
|
showDeleteAlert: $showDeleteAlert
|
||||||
|
)
|
||||||
Text("已下载")
|
|
||||||
.font(.caption2)
|
|
||||||
.foregroundColor(.gray)
|
|
||||||
.padding(.top, 4)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if isDownloading {
|
|
||||||
Button(action: {
|
|
||||||
Task {
|
|
||||||
await viewModel.cancelDownload()
|
|
||||||
}
|
|
||||||
}) {
|
|
||||||
ProgressView(value: downloadProgress)
|
|
||||||
.progressViewStyle(CircularProgressViewStyle())
|
|
||||||
.frame(width: 28, height: 28)
|
|
||||||
Text(String(format: "%.2f%%", downloadProgress * 100))
|
|
||||||
.font(.caption2)
|
|
||||||
.foregroundColor(.gray)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Button(action: onDownload) {
|
|
||||||
Image(systemName: "arrow.down.circle.fill")
|
|
||||||
.font(.title2)
|
|
||||||
}
|
|
||||||
.foregroundColor(isOtherDownloading ? .gray : .primaryPurple)
|
|
||||||
.disabled(isOtherDownloading)
|
|
||||||
|
|
||||||
HStack(alignment: .bottom, spacing: 2) {
|
|
||||||
Image(systemName: "folder")
|
|
||||||
.font(.caption2)
|
|
||||||
Text(model.formattedSize)
|
|
||||||
.font(.caption2)
|
|
||||||
.lineLimit(1)
|
|
||||||
.minimumScaleFactor(0.8)
|
|
||||||
.offset(y: 1)
|
|
||||||
.onAppear {
|
|
||||||
if !model.isDownloaded && model.cachedSize == nil {
|
|
||||||
Task {
|
|
||||||
if let size = await model.fetchRemoteSize() {
|
|
||||||
await MainActor.run {
|
|
||||||
if let index = viewModel.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
|
||||||
viewModel.models[index].cachedSize = size
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.foregroundColor(.gray)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
.frame(width: 60)
|
|
||||||
}
|
}
|
||||||
.padding(.vertical, 8)
|
.padding(.vertical, 8)
|
||||||
.alert(isPresented: $showDeleteAlert) {
|
.contentShape(Rectangle()) // 确保整个区域都可以点击
|
||||||
Alert(
|
.onTapGesture {
|
||||||
title: Text("确认删除"),
|
handleRowTap()
|
||||||
message: Text("是否确认删除该模型?"),
|
}
|
||||||
primaryButton: .destructive(Text("删除")) {
|
.alert("确认删除", isPresented: $showDeleteAlert) {
|
||||||
Task {
|
Button("删除", role: .destructive) {
|
||||||
await viewModel.deleteModel(model)
|
Task {
|
||||||
}
|
await viewModel.deleteModel(model)
|
||||||
},
|
}
|
||||||
secondaryButton: .cancel(Text("取消"))
|
}
|
||||||
)
|
Button("取消", role: .cancel) { }
|
||||||
|
} message: {
|
||||||
|
Text("是否确认删除该模型?")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func handleRowTap() {
|
||||||
|
if model.isDownloaded {
|
||||||
|
return
|
||||||
|
} else if isDownloading {
|
||||||
|
Task {
|
||||||
|
await viewModel.cancelDownload()
|
||||||
|
}
|
||||||
|
} else if !isOtherDownloading {
|
||||||
|
onDownload()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ struct SwipeActionsView: View {
|
||||||
@ObservedObject var viewModel: ModelListViewModel
|
@ObservedObject var viewModel: ModelListViewModel
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
if viewModel.pinnedModelIds.contains(model.modelId) {
|
if viewModel.pinnedModelIds.contains(model.id) {
|
||||||
Button {
|
Button {
|
||||||
viewModel.unpinModel(model)
|
viewModel.unpinModel(model)
|
||||||
} label: {
|
} label: {
|
||||||
|
|
|
@ -1,147 +0,0 @@
|
||||||
//
|
|
||||||
// TBModelListView.swift
|
|
||||||
// MNNLLMiOS
|
|
||||||
//
|
|
||||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
|
||||||
//
|
|
||||||
|
|
||||||
import SwiftUI
|
|
||||||
|
|
||||||
struct TBModelListView: View {
|
|
||||||
@ObservedObject var viewModel: TBModelListViewModel
|
|
||||||
@State private var searchText = ""
|
|
||||||
@State private var selectedSource = ModelSourceManager.shared.selectedSource
|
|
||||||
@State private var showSourceMenu = false
|
|
||||||
@State private var selectedTags: Set<String> = []
|
|
||||||
@State private var selectedCategories: Set<String> = []
|
|
||||||
@State private var selectedVendors: Set<String> = []
|
|
||||||
@State private var showFilterMenu = false
|
|
||||||
|
|
||||||
var body: some View {
|
|
||||||
ScrollView {
|
|
||||||
LazyVStack(spacing: 0, pinnedViews: [.sectionHeaders]) {
|
|
||||||
Section {
|
|
||||||
modelListSection
|
|
||||||
} header: {
|
|
||||||
toolbarSection
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.searchable(text: $searchText, prompt: "Search models...")
|
|
||||||
.onChange(of: searchText) { _, newValue in
|
|
||||||
viewModel.searchText = newValue
|
|
||||||
}
|
|
||||||
.refreshable {
|
|
||||||
await viewModel.fetchModels()
|
|
||||||
}
|
|
||||||
.alert("Error", isPresented: $viewModel.showError) {
|
|
||||||
Button("OK") { }
|
|
||||||
} message: {
|
|
||||||
Text(viewModel.errorMessage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract model list section as independent view
|
|
||||||
@ViewBuilder
|
|
||||||
private var modelListSection: some View {
|
|
||||||
LazyVStack(spacing: 8) {
|
|
||||||
ForEach(Array(filteredModels.enumerated()), id: \.element.id) { index, model in
|
|
||||||
modelRowView(model: model, index: index)
|
|
||||||
|
|
||||||
if index < filteredModels.count - 1 {
|
|
||||||
Divider()
|
|
||||||
.padding(.horizontal, 16)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.padding(.vertical, 8)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract toolbar section as independent view
|
|
||||||
@ViewBuilder
|
|
||||||
private var toolbarSection: some View {
|
|
||||||
ToolbarView(
|
|
||||||
viewModel: viewModel, selectedSource: $selectedSource,
|
|
||||||
showSourceMenu: $showSourceMenu,
|
|
||||||
selectedTags: $selectedTags,
|
|
||||||
selectedCategories: $selectedCategories,
|
|
||||||
selectedVendors: $selectedVendors,
|
|
||||||
quickFilterTags: viewModel.quickFilterTags,
|
|
||||||
showFilterMenu: $showFilterMenu,
|
|
||||||
onSourceChange: handleSourceChange
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract single model row view as independent method
|
|
||||||
@ViewBuilder
|
|
||||||
private func modelRowView(model: TBModelInfo, index: Int) -> some View {
|
|
||||||
TBModelRowView(
|
|
||||||
model: model,
|
|
||||||
viewModel: viewModel,
|
|
||||||
downloadProgress: viewModel.downloadProgress[model.id] ?? 0,
|
|
||||||
isDownloading: viewModel.currentlyDownloading == model.id,
|
|
||||||
isOtherDownloading: isOtherDownloadingCheck(model: model)
|
|
||||||
) {
|
|
||||||
Task {
|
|
||||||
await viewModel.downloadModel(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.padding(.horizontal, 16)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract complex boolean logic as independent method
|
|
||||||
private func isOtherDownloadingCheck(model: TBModelInfo) -> Bool {
|
|
||||||
return viewModel.currentlyDownloading != nil && viewModel.currentlyDownloading != model.id
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract source change handling logic as independent method
|
|
||||||
private func handleSourceChange(_ source: ModelSource) {
|
|
||||||
ModelSourceManager.shared.updateSelectedSource(source)
|
|
||||||
selectedSource = source
|
|
||||||
Task {
|
|
||||||
await viewModel.fetchModels()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter models based on selected tags, categories and vendors
|
|
||||||
private var filteredModels: [TBModelInfo] {
|
|
||||||
let baseFiltered = viewModel.filteredModels
|
|
||||||
|
|
||||||
if selectedTags.isEmpty && selectedCategories.isEmpty && selectedVendors.isEmpty {
|
|
||||||
return baseFiltered
|
|
||||||
}
|
|
||||||
|
|
||||||
return baseFiltered.filter { model in
|
|
||||||
let tagMatch = checkTagMatch(model: model)
|
|
||||||
let categoryMatch = checkCategoryMatch(model: model)
|
|
||||||
let vendorMatch = checkVendorMatch(model: model)
|
|
||||||
|
|
||||||
return tagMatch && categoryMatch && vendorMatch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract tag matching logic as independent method
|
|
||||||
private func checkTagMatch(model: TBModelInfo) -> Bool {
|
|
||||||
return selectedTags.isEmpty || selectedTags.allSatisfy { selectedTag in
|
|
||||||
model.localizedTags.contains { tag in
|
|
||||||
tag.localizedCaseInsensitiveContains(selectedTag)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract category matching logic as independent method
|
|
||||||
private func checkCategoryMatch(model: TBModelInfo) -> Bool {
|
|
||||||
return selectedCategories.isEmpty || selectedCategories.allSatisfy { selectedCategory in
|
|
||||||
model.categories?.contains { category in
|
|
||||||
category.localizedCaseInsensitiveContains(selectedCategory)
|
|
||||||
} ?? false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract vendor matching logic as independent method
|
|
||||||
private func checkVendorMatch(model: TBModelInfo) -> Bool {
|
|
||||||
return selectedVendors.isEmpty || selectedVendors.contains { selectedVendor in
|
|
||||||
model.vendor?.localizedCaseInsensitiveContains(selectedVendor) ?? false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,94 +0,0 @@
|
||||||
//
|
|
||||||
// TBModelRowView.swift
|
|
||||||
// MNNLLMiOS
|
|
||||||
//
|
|
||||||
// Created by 游薪渝(揽清) on 2025/7/4.
|
|
||||||
//
|
|
||||||
|
|
||||||
import SwiftUI
|
|
||||||
|
|
||||||
struct TBModelRowView: View {
|
|
||||||
|
|
||||||
let model: TBModelInfo
|
|
||||||
@ObservedObject var viewModel: TBModelListViewModel
|
|
||||||
|
|
||||||
let downloadProgress: Double
|
|
||||||
let isDownloading: Bool
|
|
||||||
let isOtherDownloading: Bool
|
|
||||||
let onDownload: () -> Void
|
|
||||||
|
|
||||||
@State private var showDeleteAlert = false
|
|
||||||
|
|
||||||
// 预计算本地化标签,避免重复计算
|
|
||||||
private var localizedTags: [String] {
|
|
||||||
model.localizedTags
|
|
||||||
}
|
|
||||||
|
|
||||||
// 预计算格式化大小,避免重复计算
|
|
||||||
private var formattedSize: String {
|
|
||||||
model.formattedSize
|
|
||||||
}
|
|
||||||
|
|
||||||
var body: some View {
|
|
||||||
HStack(alignment: .top, spacing: 0) {
|
|
||||||
// 模型图标
|
|
||||||
ModelIconView(modelId: model.id)
|
|
||||||
.frame(width: 40, height: 40)
|
|
||||||
|
|
||||||
// 主要信息区域
|
|
||||||
VStack(alignment: .leading, spacing: 6) {
|
|
||||||
// 模型名称
|
|
||||||
Text(model.modelName)
|
|
||||||
.font(.headline)
|
|
||||||
.fontWeight(.semibold)
|
|
||||||
.lineLimit(1)
|
|
||||||
|
|
||||||
// 标签列表
|
|
||||||
if !localizedTags.isEmpty {
|
|
||||||
TagsView(tags: localizedTags)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.padding(.leading, 8)
|
|
||||||
|
|
||||||
Spacer()
|
|
||||||
|
|
||||||
ActionButtonsView(
|
|
||||||
model: model,
|
|
||||||
viewModel: viewModel,
|
|
||||||
downloadProgress: downloadProgress,
|
|
||||||
isDownloading: isDownloading,
|
|
||||||
isOtherDownloading: isOtherDownloading,
|
|
||||||
formattedSize: formattedSize,
|
|
||||||
onDownload: onDownload,
|
|
||||||
showDeleteAlert: $showDeleteAlert
|
|
||||||
)
|
|
||||||
}
|
|
||||||
.padding(.vertical, 8)
|
|
||||||
.contentShape(Rectangle()) // 确保整个区域都可以点击
|
|
||||||
.onTapGesture {
|
|
||||||
handleRowTap()
|
|
||||||
}
|
|
||||||
.alert("确认删除", isPresented: $showDeleteAlert) {
|
|
||||||
Button("删除", role: .destructive) {
|
|
||||||
Task {
|
|
||||||
await viewModel.deleteModel(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Button("取消", role: .cancel) { }
|
|
||||||
} message: {
|
|
||||||
Text("是否确认删除该模型?")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func handleRowTap() {
|
|
||||||
if model.isDownloaded {
|
|
||||||
return
|
|
||||||
} else if isDownloading {
|
|
||||||
Task {
|
|
||||||
await viewModel.cancelDownload()
|
|
||||||
}
|
|
||||||
} else if !isOtherDownloading {
|
|
||||||
onDownload()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -19,6 +19,9 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"Audio Message" : {
|
||||||
|
|
||||||
},
|
},
|
||||||
"Benchmark" : {
|
"Benchmark" : {
|
||||||
|
|
||||||
|
@ -202,9 +205,6 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"Last used: %@" : {
|
|
||||||
|
|
||||||
},
|
},
|
||||||
"Model Configuration" : {
|
"Model Configuration" : {
|
||||||
"localizations" : {
|
"localizations" : {
|
||||||
|
@ -365,11 +365,9 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"TB模型" : {
|
|
||||||
|
|
||||||
},
|
},
|
||||||
"This is a local large model application that requires certain performance from your device.\nIt is recommended to choose different model sizes based on your device's memory. \n\nThe model recommendations for iPhone are as follows:\n- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).\n- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).\n- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).\n\nChoosing a model that is too large may cause insufficient memory and crashes." : {
|
"This is a local large model application that requires certain performance from your device.\nIt is recommended to choose different model sizes based on your device's memory. \n\nThe model recommendations for iPhone are as follows:\n- For 8GB of RAM, models up to 8B are recommended (e.g., iPhone 16 Pro).\n- For 6GB of RAM, models up to 3B are recommended (e.g., iPhone 15 Pro).\n- For 4GB of RAM, models up to 1B or smaller are recommended (e.g., iPhone 13).\n\nChoosing a model that is too large may cause insufficient memory and crashes." : {
|
||||||
|
"extractionState" : "stale",
|
||||||
"localizations" : {
|
"localizations" : {
|
||||||
"en" : {
|
"en" : {
|
||||||
"stringUnit" : {
|
"stringUnit" : {
|
||||||
|
@ -426,6 +424,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"User Guide" : {
|
"User Guide" : {
|
||||||
|
"extractionState" : "stale",
|
||||||
"localizations" : {
|
"localizations" : {
|
||||||
"zh-Hans" : {
|
"zh-Hans" : {
|
||||||
"stringUnit" : {
|
"stringUnit" : {
|
||||||
|
@ -525,9 +524,6 @@
|
||||||
},
|
},
|
||||||
"语言" : {
|
"语言" : {
|
||||||
|
|
||||||
},
|
|
||||||
"错误" : {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"version" : "1.0"
|
"version" : "1.0"
|
||||||
|
|
Loading…
Reference in New Issue