mirror of https://github.com/alibaba/MNN.git
[feat] show model size
This commit is contained in:
parent
df333dd071
commit
bae99d5e68
|
@ -44,19 +44,18 @@ struct LocalModelRowView: View {
|
|||
}
|
||||
|
||||
HStack {
|
||||
HStack(alignment: .center, spacing: 2) {
|
||||
HStack(alignment: .bottom, spacing: 2) {
|
||||
Image(systemName: "folder")
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
.frame(width: 20, height: 20)
|
||||
|
||||
Text("3.4 GB")
|
||||
Text(model.formattedSize)
|
||||
.font(.caption)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
|
||||
|
||||
Spacer()
|
||||
|
||||
|
@ -70,4 +69,4 @@ struct LocalModelRowView: View {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -21,15 +21,130 @@ struct ModelInfo: Codable {
|
|||
var isDownloaded: Bool = false
|
||||
var lastUsedAt: Date?
|
||||
|
||||
var cachedSize: Int64? = nil
|
||||
|
||||
var localPath: String {
|
||||
return HubApi.shared.localRepoLocation(HubApi.Repo.init(id: modelId)).path
|
||||
}
|
||||
|
||||
var formattedSize: String {
|
||||
if isDownloaded {
|
||||
return formatLocalSize()
|
||||
} else if let cached = cachedSize {
|
||||
return formatBytes(cached)
|
||||
} else {
|
||||
return "计算中..."
|
||||
}
|
||||
}
|
||||
|
||||
func fetchRemoteSize() async -> Int64? {
|
||||
let modelScopeId = modelId.replacingOccurrences(of: "taobao-mnn", with: "MNN")
|
||||
|
||||
do {
|
||||
let files = try await fetchFileList(repoPath: modelScopeId, root: "", revision: "")
|
||||
let totalSize = try await calculateTotalSize(files: files, repoPath: modelScopeId)
|
||||
return totalSize
|
||||
} catch {
|
||||
print("Error fetching remote size for \(modelId): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func formatLocalSize() -> String {
|
||||
let path = localPath
|
||||
guard FileManager.default.fileExists(atPath: path) else { return "未知" }
|
||||
|
||||
do {
|
||||
let totalSize = try calculateDirectorySize(at: path)
|
||||
return formatBytes(totalSize)
|
||||
} catch {
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
private func calculateDirectorySize(at path: String) throws -> Int64 {
|
||||
let fileManager = FileManager.default
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
let enumerator = fileManager.enumerator(atPath: path)
|
||||
while let fileName = enumerator?.nextObject() as? String {
|
||||
let filePath = (path as NSString).appendingPathComponent(fileName)
|
||||
let attributes = try fileManager.attributesOfItem(atPath: filePath)
|
||||
if let fileSize = attributes[.size] as? Int64 {
|
||||
totalSize += fileSize
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
private func formatBytes(_ bytes: Int64) -> String {
|
||||
let formatter = ByteCountFormatter()
|
||||
formatter.allowedUnits = [.useGB]
|
||||
formatter.countStyle = .file
|
||||
return formatter.string(fromByteCount: bytes)
|
||||
}
|
||||
|
||||
// MARK: - 云端文件大小计算方法
|
||||
|
||||
private func fetchFileList(repoPath: String, root: String, revision: String) async throws -> [ModelFile] {
|
||||
let url = try buildURL(
|
||||
repoPath: repoPath,
|
||||
path: "/repo/files",
|
||||
queryItems: [
|
||||
URLQueryItem(name: "Root", value: root),
|
||||
URLQueryItem(name: "Revision", value: revision)
|
||||
]
|
||||
)
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(from: url)
|
||||
try validateResponse(response)
|
||||
|
||||
let modelResponse = try JSONDecoder().decode(ModelResponse.self, from: data)
|
||||
return modelResponse.data.files
|
||||
}
|
||||
|
||||
private func calculateTotalSize(files: [ModelFile], repoPath: String) async throws -> Int64 {
|
||||
var totalSize: Int64 = 0
|
||||
|
||||
for file in files {
|
||||
if file.type == "tree" {
|
||||
let subFiles = try await fetchFileList(repoPath: repoPath, root: file.path, revision: "")
|
||||
totalSize += try await calculateTotalSize(files: subFiles, repoPath: repoPath)
|
||||
} else if file.type == "blob" {
|
||||
totalSize += Int64(file.size)
|
||||
}
|
||||
}
|
||||
|
||||
return totalSize
|
||||
}
|
||||
|
||||
private func buildURL(repoPath: String, path: String, queryItems: [URLQueryItem]) throws -> URL {
|
||||
var components = URLComponents()
|
||||
components.scheme = "https"
|
||||
components.host = "modelscope.cn"
|
||||
components.path = "/api/v1/models/\(repoPath)\(path)"
|
||||
components.queryItems = queryItems
|
||||
|
||||
guard let url = components.url else {
|
||||
throw ModelScopeError.invalidURL
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
private func validateResponse(_ response: URLResponse) throws {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200...299).contains(httpResponse.statusCode) else {
|
||||
throw ModelScopeError.invalidResponse
|
||||
}
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case modelId
|
||||
case tags
|
||||
case downloads
|
||||
case createdAt
|
||||
case cachedSize
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import SwiftUI
|
|||
|
||||
@MainActor
|
||||
class ModelListViewModel: ObservableObject {
|
||||
@Published private(set) var models: [ModelInfo] = []
|
||||
@Published var models: [ModelInfo] = []
|
||||
@Published private(set) var downloadProgress: [String: Double] = [:]
|
||||
@Published private(set) var currentlyDownloading: String?
|
||||
@Published var showError = false
|
||||
|
@ -77,12 +77,36 @@ class ModelListViewModel: ObservableObject {
|
|||
// Sort models
|
||||
sortModels(fetchedModels: &fetchedModels)
|
||||
|
||||
// 异步获取未下载模型的大小信息
|
||||
Task {
|
||||
await fetchModelSizes(for: fetchedModels)
|
||||
}
|
||||
|
||||
} catch {
|
||||
showError = true
|
||||
errorMessage = "Error: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchModelSizes(for models: [ModelInfo]) async {
|
||||
await withTaskGroup(of: Void.self) { group in
|
||||
for (_, model) in models.enumerated() {
|
||||
if !model.isDownloaded && model.cachedSize == nil {
|
||||
group.addTask {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
// 查找当前模型在实际数组中的索引
|
||||
if let modelIndex = self.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
self.models[modelIndex].cachedSize = size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func recordModelUsage(modelId: String) {
|
||||
ModelStorageManager.shared.updateLastUsed(for: modelId)
|
||||
if let index = models.firstIndex(where: { $0.modelId == modelId }) {
|
||||
|
|
|
@ -36,7 +36,7 @@ struct ModelListView: View {
|
|||
.font(.system(size: 12, weight: .regular))
|
||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
||||
Image(systemName: "chevron.down")
|
||||
.frame(width: 12, height: 12, alignment: .leading)
|
||||
.frame(width: 10, height: 10, alignment: .leading)
|
||||
.scaledToFit()
|
||||
.foregroundColor(showOptions ? .primaryBlue : .black )
|
||||
}
|
||||
|
|
|
@ -100,10 +100,24 @@ struct ModelRowView: View {
|
|||
HStack(alignment: .bottom, spacing: 2) {
|
||||
Image(systemName: "folder")
|
||||
.font(.caption2)
|
||||
// Text(model.formattedSize)
|
||||
Text("3.6 GB")
|
||||
Text(model.formattedSize)
|
||||
.font(.caption2)
|
||||
.padding(.top, 4)
|
||||
.lineLimit(1)
|
||||
.minimumScaleFactor(0.8)
|
||||
.offset(y: 1)
|
||||
.onAppear {
|
||||
if !model.isDownloaded && model.cachedSize == nil {
|
||||
Task {
|
||||
if let size = await model.fetchRemoteSize() {
|
||||
await MainActor.run {
|
||||
if let index = viewModel.models.firstIndex(where: { $0.modelId == model.modelId }) {
|
||||
viewModel.models[index].cachedSize = size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.foregroundColor(.gray)
|
||||
}
|
||||
|
|
|
@ -9,12 +9,6 @@
|
|||
},
|
||||
"%lld" : {
|
||||
|
||||
},
|
||||
"3.4 GB" : {
|
||||
|
||||
},
|
||||
"3.6 GB" : {
|
||||
|
||||
},
|
||||
"Are you sure you want to delete this history?" : {
|
||||
"localizations" : {
|
||||
|
|
Loading…
Reference in New Issue