mirror of https://github.com/alibaba/MNN.git
Merge branch 'master' into feature/rvv-opt
This commit is contained in:
commit 4cf16d6761
@@ -236,7 +236,7 @@ option(MNN_OPENGL "Enable OpenGL" OFF)
option(MNN_VULKAN "Enable Vulkan" OFF)
option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF)
-option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF)
+option(MNN_KLEIDIAI "Enable KLEIDIAI" ON)
option(MNN_ONEDNN "Enable oneDNN" OFF)
option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON)
option(MNN_AVX512 "Enable AVX512" OFF)
@@ -11,19 +11,13 @@

## News 🔥
- [2025/08/05] MNN Chat Android is available on [Google Play](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release)!
- [2025/06/11] New app MNN-TaoAvatar released: you can talk with a 3D avatar offline, with LLM, ASR, TTS, A2BS, and NNR models all running locally on your device! [MNN-TaoAvatar](./apps/Android/MnnTaoAvatar/README.md)
<p align="center">
  <img width="20%" alt="Icon" src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;">
</p>

- [2025/05/30] The MNN Chat app supports DeepSeek-R1-0528-Qwen3, Qwen3-30B-A3B, SmolVLM, and FastVLM. [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases)
- [2025/05/12] The Android app supports Qwen2.5 Omni 3B and 7B. [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases)
<p align="center">
  <img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_home_new.jpg" style="margin: 0 10px;">
  <img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_sound_new.jpg" style="margin: 0 10px;">
  <img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_image_new.jpg" style="margin: 0 10px;">
</p>

<details>
<summary> History News </summary>
@@ -3,6 +3,8 @@

[Download](#releases) [下载](./README_CN.md#releases)

[GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release)

[iOS App](../../iOS/MNNLLMChat/README.md)

## Introduction
@@ -59,6 +61,14 @@ This is our full multimodal language model (LLM) Android app
```

# Releases
## Version 0.6.8
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk)
+ Added new models: SmolLM3-3B and gemma-3-1b.
+ Support the penalty sampler in mixed-sampler mode.
+ Models can now be switched from the chat screen.
+ Models can be updated when the remote models change.
+ Fixed the Hugging Face download source.
+ Support real-time voice calls with ASR and TTS.
## Version 0.5.1.2
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk)
+ Fixed a Hugging Face download error.
@@ -53,6 +53,13 @@
```

# Releases
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk)
+ New model support: SmolLM3-3B and gemma-3-1b are now supported.
+ Sampler enhancement: the mixed-sampler mode now supports a penalty sampler, improving generation quality and diversity.
+ Model switching: models can now be switched in real time from the chat screen.
+ Model hot update: when a remote model changes, it can be updated without deleting the existing model, avoiding repeated downloads.
+ Download source optimization: the HuggingFace download source has been optimized for speed and stability.
+ New voice-call feature: supports real-time voice conversations, integrating speech recognition (ASR) and speech synthesis (TTS) for a richer interactive experience.

## Version 0.5.1.2
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk)
@@ -1 +1,5 @@
/build
+
+# Native libraries (downloaded at build time)
+src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.so
+src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.zip
@@ -5,6 +5,49 @@ plugins {
}

+task downloadAndUnzipNativeLibs {
+    group = 'Pre-build'
+    description = 'Downloads and unzips libsherpa-mnn-jni native library from CDN.'
+    def nativeLibsUrl = 'https://meta.alicdn.com/data/mnn/libs/libsherpa-mnn-jni-16k.zip'
+    def zipFileName = 'libsherpa-mnn-jni-16k.zip'
+    def outputDir = file('src/main/jniLibs/arm64-v8a')
+    def downloadedZip = new File(project.buildDir, zipFileName)
+    def checkFile = new File(outputDir, 'libsherpa-mnn-jni.so')
+
+    inputs.property('url', nativeLibsUrl)
+    outputs.file(checkFile)
+
+    doLast {
+        println "-> Executing downloadAndUnzipNativeLibs task..."
+        println "   Downloading from ${nativeLibsUrl}"
+
+        // Create output directory if it doesn't exist
+        outputDir.mkdirs()
+
+        ant.get(src: nativeLibsUrl, dest: downloadedZip)
+
+        if (!downloadedZip.exists()) {
+            throw new GradleException("Download failed: ${downloadedZip} not found.")
+        }
+        println "   Download complete."
+        println "   Unzipping ${downloadedZip.name} to ${outputDir}..."
+
+        copy {
+            from(zipTree(downloadedZip))
+            into(outputDir)
+        }
+        println "   Unzip complete."
+        downloadedZip.delete()
+    }
+
+    onlyIf {
+        println "-> Checking if native libs exist... [Exists: ${checkFile.exists()}]"
+        return !checkFile.exists()
+    }
+}
+
+preBuild.dependsOn downloadAndUnzipNativeLibs
+
android {
    namespace 'com.alibaba.mnnllm.android'
    compileSdk 35
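The task can also be exercised on its own, e.g. with `./gradlew downloadAndUnzipNativeLibs` (the module prefix may vary per project layout); thanks to the `onlyIf` guard above, it becomes a no-op on later builds once `libsherpa-mnn-jni.so` already exists under `src/main/jniLibs/arm64-v8a`.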
@@ -16,8 +59,8 @@ android {
        applicationId "com.alibaba.mnnllm.android"
        minSdk 26
        targetSdk 35
-       versionCode 606
-       versionName "0.6.6"
+       versionCode 608
+       versionName "0.6.8"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
@@ -56,18 +99,22 @@ android {

    signingConfigs {
        release {
-           storeFile file(System.getenv("KEYSTORE_FILE"))
-           storePassword System.getenv("KEYSTORE_PASSWORD")
-           keyAlias System.getenv("KEY_ALIAS")
-           keyPassword System.getenv("KEY_PASSWORD")
+           def keystoreFile = System.getenv("KEYSTORE_FILE")
+           if (keystoreFile) {
+               storeFile file(keystoreFile)
+               storePassword System.getenv("KEYSTORE_PASSWORD")
+               keyAlias System.getenv("KEY_ALIAS")
+               keyPassword System.getenv("KEY_PASSWORD")
+           }
        }
    }
    buildTypes {
        release {
-           signingConfig signingConfigs.release
+           if (System.getenv("KEYSTORE_FILE")) {
+               signingConfig signingConfigs.release
+           }
            applicationIdSuffix ".release"
        }
        debug {
            versionNameSuffix ".gp"
        }
    }
}
@@ -1,3 +1,3 @@
<resources>
-    <string name="app_name">MNN Chat-dev</string>
+    <string name="app_name">MNN Chat</string>
</resources>
@@ -1,3 +1,3 @@
<resources>
-    <string name="app_name">MNN Chat-dev</string>
+    <string name="app_name">MNN Chat</string>
</resources>
File diff suppressed because it is too large.
@@ -0,0 +1,97 @@
// Created by ruoyi.sjd on 2025/1/27.
// Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved.

package com.alibaba.mls.api.download

import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch

/**
 * Centralized coroutine management for download operations.
 * Provides a configurable scope and dispatcher to avoid threading issues.
 *
 * Usage:
 * ```
 * // Configure in Application.onCreate()
 * DownloadCoroutineManager.configureDispatcher(Dispatchers.IO) // or Dispatchers.Default
 *
 * // In downloaders
 * DownloadCoroutineManager.launchDownload {
 *     // download logic
 * }
 * ```
 */
object DownloadCoroutineManager {
    private const val TAG = "DownloadCoroutineManager"

    /**
     * Configurable dispatcher for download operations.
     * Defaults to Dispatchers.Default to avoid IO-dispatcher blocking issues.
     */
    var downloadDispatcher: CoroutineDispatcher = Dispatchers.Default
        private set

    /**
     * Coroutine scope for download operations
     */
    private val _downloadScope by lazy {
        CoroutineScope(downloadDispatcher + SupervisorJob())
    }

    /**
     * Configure the dispatcher used for download operations
     * @param dispatcher The coroutine dispatcher to use
     */
    fun configureDispatcher(dispatcher: CoroutineDispatcher) {
        Log.d(TAG, "Configuring download dispatcher to: $dispatcher")
        downloadDispatcher = dispatcher
    }

    /**
     * Initialize with the IO dispatcher (call this when the system is stable)
     */
    fun initializeWithIO() {
        configureDispatcher(Dispatchers.IO)
    }

    /**
     * Initialize with the Default dispatcher (call this when the IO dispatcher has issues)
     */
    fun initializeWithDefault() {
        configureDispatcher(Dispatchers.Default)
    }

    /**
     * Launch a download coroutine with proper error handling
     * @param block The suspend function to execute
     */
    fun launchDownload(block: suspend CoroutineScope.() -> Unit) {
        _downloadScope.launch(downloadDispatcher, block = block)
    }

    /**
     * Get a scope with the download dispatcher for manual coroutine management
     */
    fun getDownloadScope(): CoroutineScope {
        return _downloadScope
    }

    /**
     * Reset to the default configuration (for testing or recovery)
     */
    fun resetToDefault() {
        Log.d(TAG, "Resetting to default dispatcher (Dispatchers.Default)")
        downloadDispatcher = Dispatchers.Default
    }

    /**
     * Get current dispatcher info for debugging
     */
    fun getCurrentDispatcherInfo(): String {
        return "Current download dispatcher: $downloadDispatcher"
    }
}
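A minimal sketch of how this manager is meant to be wired up, following the kdoc above. The `Application` subclass name is hypothetical; only `DownloadCoroutineManager` itself comes from the diff:

```kotlin
import android.app.Application
import com.alibaba.mls.api.download.DownloadCoroutineManager

// Hypothetical Application subclass for illustration.
class MnnChatApplication : Application() {
    override fun onCreate() {
        super.onCreate()
        // Downloads default to Dispatchers.Default; switch to IO once startup is stable.
        DownloadCoroutineManager.initializeWithIO()
    }
}

// Inside a downloader, work is launched on the configured dispatcher:
fun enqueue(modelId: String) {
    DownloadCoroutineManager.launchDownload {
        // suspendable download logic for modelId runs here
    }
}
```

Routing every download through one `SupervisorJob`-backed scope means a single failed download does not cancel its sibling downloads.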
@@ -9,6 +9,7 @@ data class DownloadInfo(
    var totalSize: Long = 0,
    var speedInfo: String = "",
    var errorMessage: String? = null,
+   var errorException: Exception? = null,
    var lastLogTime: Long = 0,
    var lastProgressUpdateTime: Long = 0,
    var progressStage: String = "",
@@ -26,6 +26,7 @@ object DownloadPersistentData {
    const val METADATA_KEY: String = "meta_data"
    const val SIZE_TOTAL_KEY: String = "size_total"
    const val SIZE_SAVED_KEY: String = "size_saved"
+   const val SIZE_MARKET_TOTAL_KEY: String = "size_market_total"
    const val DOWNLOADED_TIME_KEY: String = "downloaded_time"

    // Create preference keys with modelId
@@ -38,6 +39,9 @@ object DownloadPersistentData {
    private fun createMetaDataKey(modelId: String): Preferences.Key<String> =
        stringPreferencesKey("${METADATA_KEY}_$modelId")

+   private fun createSizeMarketTotalKey(modelId: String): Preferences.Key<Long> =
+       longPreferencesKey("${SIZE_MARKET_TOTAL_KEY}_$modelId")
+
    private fun createDownloadedTimeKey(modelId: String): Preferences.Key<Long> =
        longPreferencesKey("${DOWNLOADED_TIME_KEY}_$modelId")
@@ -204,6 +208,35 @@ object DownloadPersistentData {

        return dataStoreValue ?: 0L
    }

+   fun saveMarketSizeTotal(context: Context, modelId: String, total: Long) {
+       runBlocking { saveMarketSizeTotalSuspend(context, modelId, total) }
+   }
+
+   suspend fun saveMarketSizeTotalSuspend(context: Context, modelId: String, total: Long) {
+       val normalizedModelId = ModelUtils.safeModelId(modelId)
+       val key = createSizeMarketTotalKey(normalizedModelId)
+
+       context.downloadDataStore.edit { preferences ->
+           preferences[key] = total
+       }
+   }
+
+   fun getMarketSizeTotal(context: Context, modelId: String): Long {
+       return runBlocking { getMarketSizeTotalSuspend(context, modelId) }
+   }
+
+   suspend fun getMarketSizeTotalSuspend(context: Context, modelId: String): Long {
+       val normalizedModelId = ModelUtils.safeModelId(modelId)
+       val key = createSizeMarketTotalKey(normalizedModelId)
+
+       // Try to read from DataStore
+       val dataStoreValue = context.downloadDataStore.data
+           .map { preferences -> preferences[key] }
+           .first()
+
+       return dataStoreValue ?: 0L
+   }
+
    // Private migration helpers
    private fun deleteSharedPrefsFileIfEmpty(context: Context, preferencesName: String) {
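A short sketch of the intended save/read pairing for the new market-size keys. The helper function is hypothetical; the two suspend functions it calls are from the hunk above:

```kotlin
import android.content.Context
import com.alibaba.mls.api.download.DownloadPersistentData

// Hypothetical helper: persist the catalog-provided size once, then
// read it back later when a live size lookup is unavailable.
suspend fun rememberMarketSize(context: Context, modelId: String, catalogSize: Long): Long {
    if (catalogSize > 0) {
        DownloadPersistentData.saveMarketSizeTotalSuspend(context, modelId, catalogSize)
    }
    return DownloadPersistentData.getMarketSizeTotalSuspend(context, modelId) // 0L if never saved
}
```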
@@ -24,8 +24,6 @@ import com.alibaba.mnnllm.android.utils.CurrentActivityTracker
import com.alibaba.mnnllm.android.utils.FileUtils
import com.alibaba.mnnllm.android.utils.MmapUtils.clearMmapCache
import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.MainScope
-import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.util.concurrent.ConcurrentHashMap
@@ -42,7 +40,7 @@ class LoggingDownloadListener : DownloadListener {
    }

    override fun onDownloadProgress(modelId: String, downloadInfo: DownloadInfo) {
-       Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo}")
+       Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo} stage: ${downloadInfo.progressStage}")
    }

    override fun onDownloadFinished(modelId: String, path: String) {
@@ -253,7 +251,6 @@ class ModelDownloadManager private constructor(context: Context) {
    fun getDownloadInfo(modelId: String): DownloadInfo {
        Log.d(TAG, "getDownloadInfo: $modelId totalSize: ${DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelId)}" +
                " progress: ${getRealDownloadSize(modelId)}")
-       val source = ModelUtils.getSource(modelId)!!
        if (!downloadInfoMap.containsKey(modelId)) {
            val downloadInfo = DownloadInfo()
            if (getDownloadedFile(modelId) != null) {
@@ -277,7 +274,11 @@ class ModelDownloadManager private constructor(context: Context) {
            }
            downloadInfoMap[modelId] = downloadInfo
            if (downloadInfo.totalSize < 100) {
-               getRepoSize(modelId, source)
+               val marketSize = DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
+               Log.d(TAG, "getMarketSize for ${modelId} size: $marketSize")
+               if (marketSize > 0) {
+                   downloadInfo.totalSize = marketSize
+               }
            }
        }
        val downloadInfo = downloadInfoMap[modelId]!!
@@ -305,26 +306,6 @@ class ModelDownloadManager private constructor(context: Context) {
        return savedSize
    }

-   private fun getRepoSize(
-       modelId: String,
-       source: Any,
-       modelName: String = modelId
-   ) {
-       MainScope().launch {
-           val downloader = when (source) {
-               is String -> getDownloaderForSource(source)
-               is ModelSources.ModelSourceType -> getDownloaderForSource(source)
-               else -> throw IllegalArgumentException("Unsupported source type: ${source::class.java.name}")
-           }
-           val repoSize = downloader.getRepoSize(modelId)
-           if (repoSize > 0 && DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelName) <= 0L) {
-               DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelName, repoSize)
-               downloadInfoMap[modelName]?.totalSize = repoSize
-               listeners.forEach { it.onDownloadTotalSize(modelName, repoSize) }
-           }
-       }
-   }
-
    private fun setDownloadFinished(modelId: String, path: String) {
        val downloadInfo = downloadInfoMap[modelId] ?: return
        downloadInfo.downloadState = DownloadState.COMPLETED
@@ -422,6 +403,7 @@ class ModelDownloadManager private constructor(context: Context) {
        val info = downloadInfoMap.getOrPut(modelId) { DownloadInfo() }
        info.downloadState = DownloadState.FAILED
        info.errorMessage = e.message
+       info.errorException = e
        listeners.forEach {
            Log.d(TAG, "[setDownloadFailed] Notifying listener: ${it.javaClass.simpleName}")
            it.onDownloadFailed(modelId, e)
@@ -470,9 +452,6 @@ class ModelDownloadManager private constructor(context: Context) {
        savedSize: Long,
        totalSize: Long
    ) {
-       if (totalSize <= 0) {
-           return
-       }
        val downloadInfo = downloadInfoMap.getOrPut(modelId) { DownloadInfo() }
        val currentTime = System.currentTimeMillis()
        if (stage == downloadInfo.progressStage && currentTime - downloadInfo.lastProgressUpdateTime < 1000) {
@@ -486,11 +465,13 @@ class ModelDownloadManager private constructor(context: Context) {
        if (downloadInfo.savedSize <= 0) {
            downloadInfo.savedSize = savedSize
        }
-       downloadInfo.progress = savedSize.toDouble() / totalSize
-       DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize)
-       DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize)
-       calculateDownloadSpeed(downloadInfo, savedSize)
-       Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId")
+       if (totalSize > 0) {
+           downloadInfo.progress = savedSize.toDouble() / totalSize
+           DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize)
+           DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize)
+           calculateDownloadSpeed(downloadInfo, savedSize)
+       }
+       Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId stage: $stage")
        listeners.forEach { it.onDownloadProgress(modelId, downloadInfo) }
    }
@@ -25,8 +25,7 @@ import com.alibaba.mls.api.hf.HfFileMetadata
import com.alibaba.mls.api.hf.HfRepoInfo
import com.alibaba.mnnllm.android.chat.model.ChatDataManager
import com.alibaba.mnnllm.android.utils.TimeUtils
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.runBlocking
+import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
import java.io.File
@@ -60,10 +59,11 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
     * Unified method to fetch repo information from HuggingFace API
     * @param modelId the model ID to fetch info for
     * @param calculateSize whether to calculate repo size (requires additional network requests)
-    * @return HfRepoInfo object or null if failed
+    * @return HfRepoInfo object
+    * @throws FileDownloadException if failed to fetch repo info
     */
-   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo? {
-       return withContext(Dispatchers.IO) {
+   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val hfModelId = hfModelId(modelId)
                val response = getHfApiClient().apiService.getRepoInfo(hfModelId, "main")?.execute()
@@ -79,11 +79,16 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
                    callback?.onRepoInfo(modelId, lastModified, repoSize)
                    repoInfo
                } else {
-                   null
+                   val errorMsg = if (response?.isSuccessful == false) {
+                       "API request failed with code ${response.code()}: ${response.message()}"
+                   } else {
+                       "API response was null or empty"
+                   }
+                   throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
                }
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
-               null
+               throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
            }
        }
    }
@@ -92,7 +97,7 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
     * Calculate total size of all files in the repo
     */
    private suspend fun calculateRepoSize(hfRepoInfo: HfRepoInfo): Long {
-       return withContext(Dispatchers.IO) {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val metaList = requestMetaDataList(hfRepoInfo)
                metaList.sumOf { it?.size ?: 0L }
@@ -101,20 +106,22 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
    }

    override fun download(modelId: String) {
-       executor!!.submit {
-           runBlocking {
-               try {
-                   val repoInfo = fetchRepoInfo(modelId)
-                   if (repoInfo != null) {
-                       downloadHfRepo(repoInfo)
-                   } else {
-                       callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
-                   }
-               } catch (e: FileDownloadException) {
-                   callback?.onDownloadFailed(modelId, e)
-               }
-           }
-       }
+       DownloadCoroutineManager.launchDownload {
+           try {
+               val repoInfo = fetchRepoInfo(modelId)
+               downloadHfRepo(repoInfo)
+           } catch (e: FileDownloadException) {
+               callback?.onDownloadFailed(modelId, e)
+           }
+       }
    }

    override suspend fun checkUpdate(modelId: String) {
-       fetchRepoInfo(modelId)
+       try {
+           fetchRepoInfo(modelId)
+       } catch (e: FileDownloadException) {
+           Log.e(TAG, "Failed to check update for $modelId", e)
+       }
    }

    fun downloadHfRepo(hfRepoInfo: HfRepoInfo) {
@@ -127,17 +134,23 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
    }

    override suspend fun getRepoSize(modelId: String): Long {
-       return withContext(Dispatchers.IO) {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
-               if (repoInfo != null) {
-                   // Size was already calculated in fetchRepoInfo, but we can also calculate it directly
-                   calculateRepoSize(repoInfo)
-               } else {
-                   0L
-               }
+               // Size was already calculated in fetchRepoInfo, but we can also calculate it directly
+               calculateRepoSize(repoInfo)
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to get repo size for $modelId", exception)
+               // Try to get file_size from saved market data as fallback
+               try {
+                   val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
+                   if (marketSize > 0) {
+                       Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
+                       return@withContext marketSize
+                   }
+               } catch (e: Exception) {
+                   Log.w(TAG, "Failed to get saved market size for $modelId", e)
+               }
                0L
            }
        }
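The same catch-and-fall-back-to-market-size block recurs in the Modelers and ModelScope downloaders below. A hedged sketch of how it could be factored into a shared helper (hypothetical function, not part of this commit; `DownloadPersistentData`, `ApplicationProvider`, and the log messages come from the diff):

```kotlin
import android.util.Log
import com.alibaba.mls.api.ApplicationProvider
import com.alibaba.mls.api.download.DownloadPersistentData

// Hypothetical shared helper for the repeated getRepoSize fallback.
fun marketSizeFallback(tag: String, modelId: String): Long = try {
    val marketSize = DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
    if (marketSize > 0) {
        Log.d(tag, "Using saved market size as fallback for $modelId: $marketSize")
        marketSize
    } else {
        0L
    }
} catch (e: Exception) {
    Log.w(tag, "Failed to get saved market size for $modelId", e)
    0L
}
```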
@@ -22,8 +21,7 @@ import com.alibaba.mls.api.ml.FileInfo
import com.alibaba.mls.api.ml.MlApiClient
import com.alibaba.mls.api.ml.MlRepoInfo
import com.alibaba.mnnllm.android.model.ModelUtils
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.launch
+import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import java.io.File
import com.alibaba.mls.api.ml.MlRepoData
@@ -43,15 +42,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
     * Unified method to fetch repo information from Modelers API
     * @param modelId the model ID to fetch info for
     * @param calculateSize whether to calculate repo size (requires additional network requests)
-    * @return MlRepoInfo object or null if failed
+    * @return MlRepoInfo object
+    * @throws FileDownloadException if failed to fetch repo info
     */
-   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo? {
-       return withContext(Dispatchers.IO) {
+   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val modelersId = ModelUtils.getRepositoryPath(modelId)
                val split = modelersId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
                if (split.size != 2) {
-                   return@runCatching null
+                   throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo")
                }

                val response = mlApiClient.apiService.getModelFiles(split[0], split[1], "").execute()
@@ -84,24 +84,33 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
                    callback?.onRepoInfo(modelId, lastModified, repoSize)
                    repoInfo
                } else {
-                   null
+                   val errorMsg = if (!response.isSuccessful) {
+                       "API request failed with code ${response.code()}: ${response.message()}"
+                   } else {
+                       "API response was null or empty"
+                   }
+                   throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
                }
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
-               null
+               throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
            }
        }
    }

    override fun download(modelId: String) {
        Log.d(TAG, "start download ${modelId}")
-       DownloadExecutor.executeScope.launch {
+       DownloadCoroutineManager.launchDownload {
            downloadRepo(modelId)
        }
    }

    override suspend fun checkUpdate(modelId: String) {
-       fetchRepoInfo(modelId, false)
+       try {
+           fetchRepoInfo(modelId, false)
+       } catch (e: FileDownloadException) {
+           Log.e(TAG, "Failed to check update for $modelId", e)
+       }
    }

    override fun getDownloadPath(modelId: String): File {
@@ -125,16 +134,22 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
    }

    override suspend fun getRepoSize(modelId: String): Long {
-       return withContext(Dispatchers.IO) {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
-               if (repoInfo != null) {
-                   repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size }
-               } else {
-                   0L
-               }
+               repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size }
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to get repo size for $modelId", exception)
+               // Try to get file_size from saved market data as fallback
+               try {
+                   val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
+                   if (marketSize > 0) {
+                       Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
+                       return@withContext marketSize
+                   }
+               } catch (e: Exception) {
+                   Log.w(TAG, "Failed to get saved market size for $modelId", e)
+               }
                0L
            }
        }
@@ -149,15 +164,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
            return null
        }

-       val modelInfo = fetchRepoInfo(modelId, calculateSize = true)
-       if (modelInfo != null) {
-           callback?.onDownloadTaskAdded()
-           downloadMlRepoInner(modelId, modelersId, modelInfo)
-           callback?.onDownloadTaskRemoved()
-       } else {
-           callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
-       }
-       return modelInfo
+       try {
+           val modelInfo = fetchRepoInfo(modelId, calculateSize = true)
+           callback?.onDownloadTaskAdded()
+           downloadMlRepoInner(modelId, modelersId, modelInfo)
+           callback?.onDownloadTaskRemoved()
+           return modelInfo
+       } catch (e: FileDownloadException) {
+           callback?.onDownloadFailed(modelId, e)
+           return null
+       }
    }

    private fun getAllFiles(owner: String, repo: String, path: String, allFiles: MutableList<FileInfo>) {
@@ -6,7 +6,6 @@ import android.util.Log
import com.alibaba.mls.api.ApplicationProvider
import com.alibaba.mls.api.FileDownloadException
import com.alibaba.mls.api.hf.HfFileMetadata
-import com.alibaba.mls.api.download.DownloadExecutor.Companion.executor
import com.alibaba.mls.api.download.DownloadFileUtils.createSymlink
import com.alibaba.mls.api.download.DownloadFileUtils.deleteDirectoryRecursively
import com.alibaba.mls.api.download.DownloadFileUtils.getLastFileName
@@ -22,9 +21,10 @@ import com.alibaba.mls.api.ms.MsApiClient
import com.alibaba.mls.api.ms.MsRepoInfo
import com.alibaba.mnnllm.android.chat.model.ChatDataManager
import com.alibaba.mnnllm.android.model.ModelUtils
+import com.alibaba.mnnllm.android.modelmarket.ModelMarketItem
+import com.alibaba.mnnllm.android.modelmarket.ModelRepository
import com.alibaba.mnnllm.android.utils.TimeUtils
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.runBlocking
+import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import retrofit2.Call
import retrofit2.Callback
@@ -44,15 +44,18 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
     * Unified method to fetch repo information from ModelScope API
     * @param modelId the model ID to fetch info for
     * @param calculateSize whether to calculate repo size (not needed for ModelScope as size is included)
-    * @return MsRepoInfo object or null if failed
+    * @return MsRepoInfo object
+    * @throws FileDownloadException if failed to fetch repo info
     */
-   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo? {
-       return withContext(Dispatchers.IO) {
+   private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo {
+       Log.d(TAG, "fetchRepoInfo called for modelId: $modelId, calculateSize: $calculateSize")
+
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val msModelId = ModelUtils.getRepositoryPath(modelId)
                val split = msModelId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
                if (split.size != 2) {
-                   return@runCatching null
+                   throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo")
                }

                val response = msApiClient.apiService.getModelFiles(split[0], split[1]).execute()
@@ -65,25 +68,34 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
                    callback?.onRepoInfo(modelId, lastModified, repoSize)
                    repoInfo
                } else {
-                   null
+                   val errorMsg = if (!response.isSuccessful) {
+                       "API request failed with code ${response.code()}: ${response.message()}"
+                   } else {
+                       "API response was null or empty"
+                   }
+                   throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
                }
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
-               null
+               throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
            }
        }
    }

    override fun download(modelId: String) {
-       executor!!.submit {
-           kotlinx.coroutines.runBlocking {
-               downloadMsRepo(modelId)
-           }
-       }
+       Log.d(TAG, "MsModelDownloader download: $modelId")
+
+       DownloadCoroutineManager.launchDownload {
+           downloadMsRepo(modelId)
+       }
    }

    override suspend fun checkUpdate(modelId: String) {
-       fetchRepoInfo(modelId)
+       try {
+           fetchRepoInfo(modelId)
+       } catch (e: FileDownloadException) {
+           Log.e(TAG, "Failed to check update for $modelId", e)
+       }
    }

    override fun getDownloadPath(modelId: String): File {
@@ -107,16 +119,22 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
    }

    override suspend fun getRepoSize(modelId: String): Long {
-       return withContext(Dispatchers.IO) {
+       return withContext(DownloadCoroutineManager.downloadDispatcher) {
            runCatching {
                val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
-               if (repoInfo != null) {
-                   repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L
-               } else {
-                   0L
-               }
+               repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L
            }.getOrElse { exception ->
                Log.e(TAG, "Failed to get repo size for $modelId", exception)
+               // Try to get file_size from saved market data as fallback
+               try {
+                   val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
+                   if (marketSize > 0) {
+                       Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
+                       return@withContext marketSize
+                   }
+               } catch (e: Exception) {
+                   Log.w(TAG, "Failed to get saved market size for $modelId", e)
+               }
                0L
            }
        }
@@ -124,20 +142,20 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,

    private suspend fun downloadMsRepo(modelId: String) {
        val modelScopeId = ModelUtils.getRepositoryPath(modelId)
        Log.d(TAG, "MsModelDownloader downloadMsRepo: $modelId modelScopeId : $modelScopeId")
        val split = modelScopeId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
        if (split.size != 2) {
            callback?.onDownloadFailed(modelId, FileDownloadException("getRepoInfoFailed modelId format error: $modelId"))
            return
        }

-       val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
-       if (repoInfo != null) {
-           Log.d(TAG, "downloadMsRepoInner executor")
-           callback?.onDownloadTaskAdded()
-           downloadMsRepoInner(modelId, modelScopeId, repoInfo)
-           callback?.onDownloadTaskRemoved()
-       } else {
-           callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
-       }
+       try {
+           val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
+           Log.d(TAG, "downloadMsRepo repoInfo: $repoInfo")
+           callback?.onDownloadTaskAdded()
+           downloadMsRepoInner(modelId, modelScopeId, repoInfo)
+           callback?.onDownloadTaskRemoved()
+       } catch (e: FileDownloadException) {
+           callback?.onDownloadFailed(modelId, e)
+       }
    }
@@ -69,7 +69,7 @@ class BenchmarkStateMachine {
    fun isValidTransition(from: BenchmarkState, to: BenchmarkState): Boolean {
        val validTransitions = when (from) {
            BenchmarkState.IDLE -> listOf(BenchmarkState.LOADING_MODELS)
-           BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND)
+           BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND, BenchmarkState.ERROR)
            BenchmarkState.READY -> listOf(BenchmarkState.INITIALIZING, BenchmarkState.LOADING_MODELS)
            BenchmarkState.INITIALIZING -> listOf(BenchmarkState.RUNNING, BenchmarkState.ERROR)
            BenchmarkState.RUNNING -> listOf(BenchmarkState.STOPPING, BenchmarkState.COMPLETED, BenchmarkState.ERROR)
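With the added edge, a model-loading failure can now move the machine straight into `ERROR`. A small sketch of guarding a state change with this check (the guard function is hypothetical; only `isValidTransition` and the states come from the diff):

```kotlin
// Hypothetical guard: rejects transitions the table above does not allow.
fun BenchmarkStateMachine.transitionOrThrow(from: BenchmarkState, to: BenchmarkState): BenchmarkState {
    check(isValidTransition(from, to)) { "Invalid benchmark transition: $from -> $to" }
    return to
}

// Legal after this change:
// stateMachine.transitionOrThrow(BenchmarkState.LOADING_MODELS, BenchmarkState.ERROR)
```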
@@ -356,6 +356,9 @@ class MarketItemHolder(
                R.id.menu_open_model_card -> {
                    openModelCard(itemView.context, modelMarketItem)
                }
+               R.id.menu_copy_error_info -> {
+                   copyErrorInfoToClipboard(modelMarketItemWrapper)
+               }
            }
            true
        }
@@ -388,6 +391,10 @@ class MarketItemHolder(

        // Model card: always visible
        menu.findItem(R.id.menu_open_model_card).isVisible = false
+
+       // Copy error info: visible only for failed downloads
+       menu.findItem(R.id.menu_copy_error_info).isVisible =
+           downloadState == DownloadState.FAILED
    }

    private fun handleSettingsMenu(modelMarketItem: ModelMarketItem) {
@@ -409,6 +416,41 @@ class MarketItemHolder(
    private fun openModelCard(context: android.content.Context, modelMarketItem: ModelMarketItem) {
    }

+   private fun copyErrorInfoToClipboard(modelMarketItemWrapper: ModelMarketItemWrapper) {
+       val context = itemView.context
+       val downloadInfo = modelMarketItemWrapper.downloadInfo
+       val modelMarketItem = modelMarketItemWrapper.modelMarketItem
+
+       if (downloadInfo.downloadState != DownloadState.FAILED) {
+           return
+       }
+
+       val errorMessage = downloadInfo.errorMessage ?: "Unknown error"
+       val modelInfo = "Model: ${modelMarketItem.modelName} (${modelMarketItem.modelId})"
+
+       // Build error info with stack trace if available
+       val errorInfoBuilder = StringBuilder().apply {
+           appendLine(modelInfo)
+           appendLine("Error: $errorMessage")
+           appendLine("Time: ${java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", java.util.Locale.getDefault()).format(java.util.Date())}")
+
+           // Add stack trace if exception is available
+           downloadInfo.errorException?.let { exception ->
+               appendLine()
+               appendLine("Stack Trace:")
+               appendLine(android.util.Log.getStackTraceString(exception))
+           }
+       }
+
+       val errorInfo = errorInfoBuilder.toString()
+
+       val clipboard = context.getSystemService(android.content.Context.CLIPBOARD_SERVICE) as android.content.ClipboardManager
+       val clip = android.content.ClipData.newPlainText("Error Info", errorInfo)
+       clipboard.setPrimaryClip(clip)
+
+       Toast.makeText(context, context.getString(R.string.error_info_copied), Toast.LENGTH_SHORT).show()
+   }
+
    companion object {
        const val TAG = "ModelItemHolder"
    }
@@ -11,6 +11,7 @@ data class ModelMarketItem(
    val categories: List<String>,
    val sources: Map<String, String>,
    val description: String? = null,
+   @SerializedName("file_size") val fileSize: Long = 0L, // File size in bytes from model_market.json
    var currentSource: String = "", // e.g. "modelscope", "huggingface"
    var currentRepoPath: String = "", // e.g. "MNN/Qwen-1.8B-Chat-Int4"
    var modelId: String = "" // e.g. "ModelScope/MNN/Qwen-1.8B-Chat-Int4"
@@ -3,6 +3,7 @@ package com.alibaba.mnnllm.android.modelmarket
import android.content.Context
import android.util.Log
import androidx.preference.PreferenceManager
+import com.alibaba.mls.api.download.DownloadPersistentData
import com.alibaba.mnnllm.android.mainsettings.MainSettings
import com.google.gson.Gson
import kotlinx.coroutines.Dispatchers
@@ -233,6 +234,16 @@ class ModelRepository(private val context: Context) {
            item.currentSource = selectedSource
            item.currentRepoPath = item.sources[selectedSource]!!
            item.modelId = "$selectedSource/${item.sources[selectedSource]!!}"
+
+           // Save file_size to DownloadPersistentData if available
+           if (item.fileSize > 0) {
+               try {
+                   DownloadPersistentData.saveMarketSizeTotalSuspend(context, item.modelId, item.fileSize)
+                   Log.d(TAG, "Saved market size for ${item.modelId}: ${item.fileSize}")
+               } catch (e: Exception) {
+                   Log.w(TAG, "Failed to save market size for ${item.modelId}", e)
+               }
+           }
        }
    }
@@ -19,4 +19,7 @@
    <item
        android:id="@+id/menu_open_model_card"
        android:title="@string/open_model_card" />
+   <item
+       android:id="@+id/menu_copy_error_info"
+       android:title="@string/menu_copy_error_info" />
</menu>
@@ -485,4 +485,6 @@
    <string name="allow_network_market_data">允许联网获取模型市场数据</string>
    <string name="enable_network_delay">启用网络延时(3秒)</string>
    <string name="update_available">有新版本</string>
+   <string name="menu_copy_error_info">复制错误信息</string>
+   <string name="error_info_copied">错误信息已复制到剪贴板</string>
</resources>
@@ -488,4 +488,6 @@
    <string name="allow_network_market_data">Allow network to fetch model market data</string>
    <string name="enable_network_delay">Enable network delay (3 seconds)</string>
    <string name="update_available">Update Available</string>
+   <string name="menu_copy_error_info">Copy Error Info</string>
+   <string name="error_info_copied">Error information copied to clipboard</string>
</resources>
@@ -0,0 +1,126 @@
import json
import argparse
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError


def get_repo_size_in_bytes(repo_id: str) -> int:
    """
    Calculates the total size of all files in a Hugging Face model repository.

    Args:
        repo_id (str): The ID of the model repository, e.g., 'google-bert/bert-base-uncased'.

    Returns:
        int: The total size of the files in bytes. Returns -1 if the repository
             cannot be found or an error occurs.
    """
    # Initialize the HfApi client
    api = HfApi()
    total_size = 0

    try:
        # Fetch model information, including metadata for each file
        print(f"Fetching metadata for repository: '{repo_id}'...")
        model_info = api.model_info(repo_id=repo_id, files_metadata=True)

        # Sum the size of each file in the repository
        for file in model_info.siblings:
            if file.size is not None:
                total_size += file.size

        print(f"Successfully calculated size for '{repo_id}': {total_size} bytes")
        return total_size

    except HfHubHTTPError as e:
        # Handle cases where the repository is not found or other HTTP errors
        print(f"Error: Could not retrieve info for repository '{repo_id}'. It might not exist or be private.")
        print(f"Details: {e}")
        return -1
    except Exception as e:
        # Handle other potential exceptions
        print(f"An unexpected error occurred while processing '{repo_id}': {e}")
        return -1


def process_model_list(models: list):
    """
    Iterates through a list of model objects and adds the 'file_size' field.

    Args:
        models (list): A list of dictionaries, where each dictionary represents a model.
    """
    if not models:
        return  # Do nothing if the list is empty

    for model in models:
        # Check if the model has a HuggingFace source
        if 'sources' in model and 'HuggingFace' in model['sources']:
            repo_id = model['sources']['HuggingFace']
            if repo_id:
                # Get the repository size and add it to the model object
                file_size = get_repo_size_in_bytes(repo_id)
                model['file_size'] = file_size
            else:
                # Handle cases where the HuggingFace repo_id is empty
                model['file_size'] = -1
                print(f"Warning: Empty HuggingFace repo_id for model '{model.get('modelName', 'N/A')}'.")
        else:
            # If no HuggingFace source, you can decide what to do.
            # Here we skip adding the field, or you could add 'file_size': 0 or -1
            print(f"Skipping model '{model.get('modelName', 'N/A')}' as it has no HuggingFace source.")


def main():
    """
    Main function to parse arguments, read the input JSON, process it,
    and write to the output JSON file.
    """
    # Set up argument parser for command-line interface
    parser = argparse.ArgumentParser(
        description="Process a market config JSON file to add Hugging Face model sizes."
    )
    parser.add_argument(
        "-i", "--input",
        required=True,
        help="Path to the input JSON file."
    )
    parser.add_argument(
        "-o", "--output",
        required=True,
        help="Path to the output JSON file."
    )

    args = parser.parse_args()

    # Read the input JSON file
    try:
        with open(args.input, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input file not found at '{args.input}'")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from the input file '{args.input}'.")
        return

    # Define the keys that contain lists of models to process
    model_list_keys = ['models', 'tts_models', 'asr_models']

    # Process each list of models
    for key in model_list_keys:
        if key in data and isinstance(data[key], list):
            print(f"\n--- Processing models in '{key}' ---")
            process_model_list(data[key])
        else:
            print(f"\n--- No models found in '{key}', skipping ---")

    # Write the updated data to the output JSON file
    try:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
        print(f"\nProcessing complete. Output successfully written to '{args.output}'")
    except IOError as e:
        print(f"Error: Could not write to output file '{args.output}'. Details: {e}")


if __name__ == '__main__':
    main()
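Assuming the script is saved next to the market config, it would be invoked with its two required flags, e.g. `python <script>.py -i model_market.json -o model_market.json`; the script filename is hypothetical, while the `-i`/`-o` flags and the processed keys (`models`, `tts_models`, `asr_models`) come from the code above. Populating `file_size` offline is what lets the app fall back to `getMarketSizeTotal` when a live repo-size lookup fails.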
@@ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
//          PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
            var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]

-           if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
+           if self?.modelInfo.tags.contains("Think") == true,
               let text = self?.processor.process(progress: message.text) {
                updateLastMsg?.text = text
            } else {
@@ -15,7 +15,7 @@ struct ChatHistoryItemView: View {

        if let lastMessage = getLastNonEmptyMessage() {
            Text(String(lastMessage.content.prefix(200)))
-               .lineLimit(1)
+               .lineLimit(3)
                .font(.system(size: 15, weight: .medium))
                .foregroundColor(.primary)
        }
@@ -709,9 +709,13 @@ bool remove_directory_safely(const std::string& path) {
    }

    std::string inputStr = [input UTF8String];
+#ifdef DEBUG
    if (inputStr == "benchmark") {
        [blockSelf performBenchmarkWithOutput:&os];
    } else {
+#else
+   {
+#endif
        // Get initial context state for performance measurement
        auto context = blockSelf->_llm->getContext();
        int initial_prompt_len = context->prompt_len;
@@ -1308,7 +1308,14 @@
      }
    },
+   "Yes" : {
+     "localizations" : {
+       "zh-Hans" : {
+         "stringUnit" : {
+           "state" : "translated",
+           "value" : "是"
+         }
+       }
+     }
+   },
    "搜索本地模型..." : {
      "localizations" : {
@@ -125,6 +125,7 @@ class BenchmarkService: ObservableObject {

    /// Releases the current model and frees associated resources
    func releaseModel() {
+       llmEngine?.cancelInference()
        llmEngine = nil
        currentModelId = nil
    }
@@ -34,6 +34,7 @@ class BenchmarkViewModel: ObservableObject {

    @Published var startButtonText = String(localized: "Start Test")
    @Published var isStartButtonEnabled = true
+   @Published var showStopConfirmation = false

    // MARK: - Private Properties
@@ -105,6 +106,7 @@ class BenchmarkViewModel: ObservableObject {

    /// Handles benchmark stop confirmation
    func onStopBenchmarkTapped() {
+       showStopConfirmation = false
        stopBenchmark()
    }
@@ -119,8 +121,8 @@ class BenchmarkViewModel: ObservableObject {
        showResults = false
        hideStatus()

-       // Release model to free memory
-       benchmarkService.releaseModel()
+       // Clean up resources when deleting results
+       cleanupBenchmarkResources()
    }

    /// Placeholder for future result submission functionality
@@ -200,14 +202,16 @@ class BenchmarkViewModel: ObservableObject {
    private func stopBenchmark() {
        updateStatus("Stopping benchmark...")
        benchmarkService.stopBenchmark()
-       MemoryMonitor.shared.stop()
+       cleanupBenchmarkResources()
    }

    // MARK: - UI State Management

    /// Updates UI state when benchmark starts
    private func onBenchmarkStarted() {
        isRunning = true
        isStartButtonEnabled = true
        startButtonText = String(localized: "Stop Test")
        showProgressBar = true
        showResults = false
        updateStatus("Initializing benchmark...")
@@ -215,11 +219,22 @@ class BenchmarkViewModel: ObservableObject {

    /// Resets UI to initial state
    private func resetUIState() {
        isRunning = false
        isStartButtonEnabled = true
        startButtonText = String(localized: "Start Test")
        showProgressBar = false
        hideStatus()
        showResults = false
+       cleanupBenchmarkResources()
    }

+   /// Cleans up benchmark resources including memory monitoring and model
+   private func cleanupBenchmarkResources() {
+       MemoryMonitor.shared.stop()
+       MemoryMonitor.shared.reset()
+
+       // Release model to free memory
+       benchmarkService.releaseModel()
+   }
+
    /// Updates status message display
@@ -238,9 +253,9 @@ class BenchmarkViewModel: ObservableObject {
        showError = true
    }

-   /// Placeholder for stop confirmation alert (handled in View)
+   /// Shows stop confirmation alert
    private func showStopConfirmationAlert() {
-       // This will be handled in the View with an alert
+       showStopConfirmation = true
    }

    /// Formats progress messages with appropriate status text based on progress type
@@ -309,11 +324,14 @@ extension BenchmarkViewModel: BenchmarkCallback {
        benchmarkResults = results
        showResults = true

-       // Only stop memory monitoring if benchmark is no longer running (all tests completed)
-       if !isRunning {
-           // Stop memory monitoring
-           MemoryMonitor.shared.stop()
-       }
+       // Update UI state to reflect completion
+       isRunning = false
+       isStartButtonEnabled = true
+       startButtonText = String(localized: "Start Test")
+       showProgressBar = false
+
+       // Clean up resources after benchmark completion
+       cleanupBenchmarkResources()

        // Always hide status after processing results
        hideStatus()
@@ -324,9 +342,16 @@ extension BenchmarkViewModel: BenchmarkCallback {
    /// Handles benchmark errors with user-friendly error messages
    func onBenchmarkError(_ errorCode: Int, _ message: String) {
        let errorCodeName = BenchmarkErrorCode(rawValue: errorCode)?.description ?? "Unknown"
-       showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
+
+       // Check if this is a user-initiated stop - don't show error dialog
+       if errorCode == BenchmarkErrorCode.benchmarkStopped.rawValue {
+           print("BenchmarkViewModel: Benchmark stopped by user (\(errorCode)): \(message)")
+       } else {
+           showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
+           print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
+       }
+
        resetUIState()
-       print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
    }
}
@@ -13,7 +13,6 @@ import SwiftUI
 */
struct ModelSelectionCard: View {
    @ObservedObject var viewModel: BenchmarkViewModel
-   @Binding var showStopConfirmation: Bool

    var body: some View {
        VStack(alignment: .leading, spacing: 16) {
@@ -140,11 +139,7 @@ struct ModelSelectionCard: View {

    private var startStopButton: some View {
        Button(action: {
-           if viewModel.startButtonText.contains("Stop") {
-               showStopConfirmation = true
-           } else {
-               viewModel.onStartBenchmarkTapped()
-           }
+           viewModel.onStartBenchmarkTapped()
        }) {
            HStack(spacing: 12) {
                ZStack {
@@ -152,12 +147,12 @@ struct ModelSelectionCard: View {
                        .fill(Color.white.opacity(0.2))
                        .frame(width: 32, height: 32)

-                   if viewModel.isRunning && viewModel.startButtonText.contains("Stop") {
+                   if viewModel.isRunning {
                        ProgressView()
                            .progressViewStyle(CircularProgressViewStyle(tint: .white))
                            .scaleEffect(0.7)
                    } else {
-                       Image(systemName: viewModel.startButtonText.contains("Stop") ? "stop.fill" : "play.fill")
+                       Image(systemName: viewModel.isRunning ? "stop.fill" : "play.fill")
                            .font(.system(size: 16, weight: .bold))
                            .foregroundColor(.white)
                    }
@@ -169,7 +164,7 @@ struct ModelSelectionCard: View {

                Spacer()

-               if !viewModel.startButtonText.contains("Stop") {
+               if !viewModel.isRunning {
                    Image(systemName: "arrow.right")
                        .font(.system(size: 16, weight: .semibold))
                        .foregroundColor(.white.opacity(0.8))
@@ -182,7 +177,7 @@ struct ModelSelectionCard: View {
            RoundedRectangle(cornerRadius: 16)
                .fill(
                    viewModel.isStartButtonEnabled ?
-                   (viewModel.startButtonText.contains("Stop") ?
+                   (viewModel.isRunning ?
                    LinearGradient(
                        colors: [Color.benchmarkError, Color.benchmarkError.opacity(0.8)],
                        startPoint: .leading,
@@ -234,8 +229,7 @@ struct ModelSelectionCard: View {

#Preview {
    ModelSelectionCard(
-       viewModel: BenchmarkViewModel(),
-       showStopConfirmation: .constant(false)
+       viewModel: BenchmarkViewModel()
    )
    .padding()
}
@@ -13,7 +13,6 @@ import SwiftUI
 */
struct BenchmarkView: View {
    @StateObject private var viewModel = BenchmarkViewModel()
-   @State private var showStopConfirmation = false

    var body: some View {
        ZStack {
@@ -21,8 +20,7 @@ struct BenchmarkView: View {
            VStack(spacing: 24) {
                // Model Selection Section
                ModelSelectionCard(
-                   viewModel: viewModel,
-                   showStopConfirmation: $showStopConfirmation
+                   viewModel: viewModel
                )

                // Progress Section
@@ -55,7 +53,7 @@ struct BenchmarkView: View {
                .padding(.vertical, 16)
            }
        }
-       .alert("Stop Benchmark", isPresented: $showStopConfirmation) {
+       .alert("Stop Benchmark", isPresented: $viewModel.showStopConfirmation) {
            Button("Yes", role: .destructive) {
                viewModel.onStopBenchmarkTapped()
            }
@@ -68,11 +66,7 @@ struct BenchmarkView: View {
        } message: {
            Text(viewModel.errorMessage)
        }
-       .onReceive(viewModel.$isRunning) { isRunning in
-           if isRunning && viewModel.startButtonText.contains("Stop") {
-               showStopConfirmation = false
-           }
-       }

    }
}
@@ -32,7 +32,9 @@ struct LocalModelListView: View {
                viewModel.selectModel(model)
            }) {
                LocalModelRowView(model: model)
+                   .contentShape(Rectangle())
            }
            .buttonStyle(PlainButtonStyle())
+           .listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
            .swipeActions(edge: .trailing, allowsFullSwipe: false) {
                SwipeActionsView(model: model, viewModel: viewModel)
@@ -59,6 +59,10 @@ struct LocalModelRowView: View {
                }
            }
        }
+
+       Spacer()
    }
+   .padding(.vertical, 8)
+   .frame(maxWidth: .infinity, alignment: .leading)
}
@@ -142,7 +142,10 @@ struct MainTabView: View {
            .edgesIgnoringSafeArea(.all)
        }
        .onChange(of: showHistory) { oldValue, newValue in
-           if !newValue {
+           if newValue {
+               // Refresh chat history when opening the side menu
+               histories = ChatHistoryManager.shared.getAllHistory()
+           } else {
                DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
                    withAnimation {
                        showHistoryButton = true
@@ -191,6 +194,9 @@ struct MainTabView: View {
 modelListViewModel.recordModelUsage(modelName: model.modelName)
 }

+// Refresh chat history when returning from chat
+histories = ChatHistoryManager.shared.getAllHistory()
+
 // Clear selections
 modelListViewModel.selectedModel = nil
 selectedHistory = nil
@@ -67,7 +67,14 @@ struct ModelInfo: Codable {
 }

 let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
-return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
+let baseId = sources[sourceKey] ?? "taobao-mnn/\(modelName)"
+
+// Add vendor prefix to ensure uniqueness for local models
+if vendor == "Local" {
+return "local/\(modelName)"
+}
+
+return baseId
 }

 var localizedTags: [String] {
@@ -72,7 +72,7 @@ class ModelListViewModel: ObservableObject {
 if !foundModelFiles.isEmpty {
 // Check if we have a complete model (at least config.json)
 if foundModelFiles.contains("config.json") {
-let modelName = "Qwen3-0.6B-MNN"
+let modelName = "Qwen3-0.6B-MNN-Inside"
 let localModel = ModelInfo(
 modelName: modelName,
 tags: [NSLocalizedString("tag.deepThinking", comment: "Deep thinking tag for local model"),
@@ -143,8 +143,10 @@ class ModelListViewModel: ObservableObject {
 let itemConfigPath = (itemPath as NSString).appendingPathComponent("config.json")

 if fileManager.fileExists(atPath: itemConfigPath) {
+// Use custom name for Qwen3-0.6B-MNN to avoid conflicts
+let modelName = item == "Qwen3-0.6B-MNN" ? "Qwen3-0.6B-MNN-Inside" : item
 let localModel = ModelInfo(
-modelName: item,
+modelName: modelName,
 tags: ["local", "bundled"],
 categories: ["Local Models"],
 vendor: "Local",
@@ -153,7 +155,7 @@ class ModelListViewModel: ObservableObject {
 )
 localModels.append(localModel)

-ModelStorageManager.shared.markModelAsDownloaded(item)
+ModelStorageManager.shared.markModelAsDownloaded(modelName)
 }
 }
 }
@@ -175,12 +177,15 @@ class ModelListViewModel: ObservableObject {

 var fetchedModels = info.models

-// Add LocalModel folder models
+// Add LocalModel folder models, avoiding duplicates
 let localModels = await loadLocalModels()
-fetchedModels.append(contentsOf: localModels)
+let existingModelNames = Set(fetchedModels.map { $0.modelName })
+let uniqueLocalModels = localModels.filter { !existingModelNames.contains($0.modelName) }
+fetchedModels.append(contentsOf: uniqueLocalModels)

 filterDiffusionModels(fetchedModels: &fetchedModels)
 loadCachedSizes(for: &fetchedModels)
+syncDownloadStatus(for: &fetchedModels)
 sortModels(fetchedModels: &fetchedModels)
 self.models = fetchedModels

@@ -203,6 +208,18 @@ class ModelListViewModel: ObservableObject {
 }
 }

+private func syncDownloadStatus(for models: inout [ModelInfo]) {
+for i in 0..<models.count {
+let isDownloaded = ModelStorageManager.shared.isModelDownloaded(models[i].modelName)
+models[i].isDownloaded = isDownloaded
+
+// Also sync last used date
+if let lastUsed = ModelStorageManager.shared.getLastUsed(for: models[i].modelName) {
+models[i].lastUsedAt = lastUsed
+}
+}
+}
+
 private func fetchModelSizes(for models: [ModelInfo]) async {
 await withTaskGroup(of: Void.self) { group in
 for (_, model) in models.enumerated() {
@@ -117,7 +117,7 @@ static inline uint64_t getTimeInUs() {
 }

 std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward = MNN_FORWARD_CPU, bool only_inference = true,
-int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) {
+int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false, bool enableKleidiAI=false) {
 auto revertor = std::unique_ptr<Revert>(new Revert(model.model_file.c_str()));
 if (testQuantModel) {
 revertor->initialize(0, sparseBlockOC, false, true);
@@ -130,6 +130,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
 auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy);
 revertor.reset();
 net->setSessionMode(MNN::Interpreter::Session_Release);
+net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, enableKleidiAI);
 MNN::ScheduleConfig config;
 config.numThread = numberThread;
 config.type = static_cast<MNNForwardType>(forward);
@@ -392,8 +393,9 @@ int main(int argc, const char* argv[]) {
 int precision = 2;
 float sparsity = 0.0f;
 int sparseBlockOC = 1;
+bool enableKleidiAI = false;
 if (argc <= 2) {
-std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel]" << std::endl;
+std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel] [enableKleidiAI]" << std::endl;
 return 1;
 }
 if (argc >= 3) {
@@ -420,8 +422,11 @@ int main(int argc, const char* argv[]) {
 if(argc >= 10) {
 testQuantizedModel = atoi(argv[9]);
 }
+if (argc >= 11) {
+enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
+}

-std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << std::endl;
+std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << " enableKleidiAI=" << enableKleidiAI << std::endl;
 std::vector<Model> models = findModelFiles(argv[1]);

 std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl;
@@ -441,10 +446,10 @@ int main(int argc, const char* argv[]) {
 }

 for (auto& m : models) {
-std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false);
+std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false, enableKleidiAI);
 displayStats(m.name.c_str(), costs, false);
 if (testQuantizedModel) {
-costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true);
+costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true, enableKleidiAI);
 displayStats(m.name, costs, 1);
 }
 }
@@ -98,5 +98,5 @@ MNN uses CMake to build the project; the CMake macro definitions are listed below:
 | MNN_SUPPORT_TRANSFORMER_FUSE | Whether to support fused implementations of Transformer-related ops, defaults to `OFF` |
 | MNN_BUILD_LLM | Whether to build the MNN-based LLM library and demo, defaults to `OFF` |
 | MNN_BUILD_DIFFUSION | Whether to build the MNN-based diffusion demo (requires MNN_BUILD_OPENCV and MNN_IMGCODECS), defaults to `OFF` |
-| MNN_KLEIDIAI | Whether to integrate Arm's KleidiAI acceleration library (currently experimental; only symmetric-quantized LLM models are supported), defaults to `OFF` |
+| MNN_KLEIDIAI | Whether to integrate Arm's KleidiAI acceleration library, defaults to `ON` |
 | MNN_USE_RVV | Whether to enable RISC-V Vector extension support, defaults to `OFF` |
@@ -252,7 +252,10 @@ public:
 CPU_CORE_IDS = 14,

 // set CPU threads to use when supports Arm sme2
-CPU_SME2_INSTRUCTIONS = 15
+CPU_SME2_INSTRUCTIONS = 15,
+
+// Enable KleidiAI
+CPU_ENABLE_KLEIDIAI = 16
 };

 enum ExternalPathType {
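For reference, a minimal sketch of driving the new hint through the public Interpreter API, mirroring what the doBench() change above does. The model path is a placeholder and error handling is omitted:

```
#include <MNN/Interpreter.hpp>
#include <memory>

int main() {
    // "model.mnn" is a placeholder path, not a file shipped with this change.
    auto net = std::shared_ptr<MNN::Interpreter>(
        MNN::Interpreter::createFromFile("model.mnn"), MNN::Interpreter::destroy);
    // Route eligible CPU convolutions through KleidiAI (any value > 0 enables it).
    net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, 1);
    MNN::ScheduleConfig config;
    config.numThread = 4;
    auto session = net->createSession(config);
    net->runSession(session);
    return 0;
}
```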
@@ -21,10 +21,6 @@
 #include "ThreadPool.hpp"
 #endif

-#ifdef MNN_KLEIDIAI_ENABLED
-#include "arm/mnn_kleidiai.h"
-#endif
-
 namespace MNN {
 class WorkerThread;
 class CPURuntime : public Runtime {
@@ -19,8 +19,7 @@ if (MNN_CPU_WEIGHT_DEQUANT_GEMM)
 FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/normal_memory/*.[sS])
 endif()

-if (MNN_KLEIDIAI)
-add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
+if (MNN_KLEIDIAI AND CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
 # Disable the KleidiAI tests
 set(KLEIDIAI_BUILD_TESTS OFF)
 # Fetch KleidiAI sources:
@@ -121,6 +120,9 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?")
 endif()
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
 message(STATUS "Enabling AArch64 Assemblies")
+if (MNN_KLEIDIAI)
+add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
+endif()
 if (MNN_SME2)
 add_definitions(-DMNN_SME2)
 FILE(GLOB MNN_SME2_AArch64_SRC ${MNN_SME2_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/sme2_asm/*.[sS])
@@ -23,6 +23,7 @@
 #include "core/OpCommonUtils.hpp"
 #include "backend/cpu/OneDNNConvolution.hpp"
+#include "backend/cpu/compute/ConvInt8TiledExecutor.hpp"

 #ifdef MNN_KLEIDIAI_ENABLED
 #include "backend/cpu/compute/KleidiAIConvInt8.hpp"
 #include "backend/cpu/compute/KleidiAIConvolution.hpp"
@@ -31,6 +32,92 @@

 namespace MNN {

+#ifdef MNN_KLEIDIAI_ENABLED
+static Execution* _createKleidiAIUnit(const Tensor* input, const Tensor* output, Backend* backend, const Op* op,
+const float* originWeight, size_t originWeightSize, const float* bias,
+size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo,
+bool supportSparse, bool lowMemory) {
+auto cpuBackend = (CPUBackend*)backend;
+auto conv2d = op->main_as_Convolution2D();
+auto common = conv2d->common();
+
+bool fastWay = common->kernelY() == 1 && common->kernelX() == 1 && output->width() == input->width() &&
+output->height() == input->height() && common->strideX() == 1 && common->strideY() == 1;
+
+#ifdef MNN_LOW_MEMORY
+if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
+if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
+do {
+if (!weightQuantInfo->canUseInt4) {
+break;
+}
+auto convOp = op->main_as_Convolution2D();
+auto core = static_cast<CPUBackend*>(backend)->functions();
+int oc = convOp->common()->outputCount();
+int ic = convOp->common()->inputCount();
+
+int blockNum = 1;
+int dequantCnt = weightQuantInfo->alphaSize;
+if (weightQuantInfo->asymmetric) {
+dequantCnt /= 2;
+}
+blockNum = dequantCnt / oc;
+
+bool bAsym = weightQuantInfo->asymmetric;
+size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
+
+KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
+
+KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
+if (!kai.canAccelerate(accelType, convOp->common())) {
+break;
+}
+
+if (!kai.isLoaded(accelType)) {
+kai.setLoaded(accelType);
+kai.printInfo(accelType);
+}
+
+return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
+} while (0);
+}
+
+// Other quantized-weight cases are not supported here yet.
+return nullptr;
+}
+#else
+if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
+if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
+return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
+weightQuantInfo);
+}
+
+// Do nothing and fall back.
+return nullptr;
+}
+#endif
+
+// This differs from the original implementation: it mirrors the Strassen path,
+// which is used when built without MNN_REDUCE_SIZE. KleidiAI does not need to
+// distinguish these cases.
+if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
+auto bytes = cpuBackend->functions()->bytes;
+auto accelType = (bytes == 2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
+KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
+if (kai.canAccelerate(accelType)) {
+return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
+}
+}
+
+if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
+return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
+weightQuantInfo);
+}
+
+return nullptr;
+}
+#endif //MNN_KLEIDIAI_ENABLED
+
 static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend* backend,
 const Op* op, const float* originWeight, size_t originWeightSize, const float* bias, size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo, bool supportSparse, bool lowMemory) {
 auto cpuBackend = (CPUBackend*)backend;
@@ -48,47 +135,24 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 }
 }
 #endif

+#ifdef MNN_KLEIDIAI_ENABLED
+if (cpuBackend->getRuntime()->hint().enableKleidiAI) {
+auto execution = _createKleidiAIUnit(input, output, backend, op, originWeight, originWeightSize, bias, biasSize,
+weightQuantInfo, supportSparse, lowMemory);
+
+if (execution) {
+return execution;
+}
+}
+#endif //MNN_KLEIDIAI_ENABLED
+
 bool fastWay = common->kernelY() == 1 && common->kernelX() == 1
 && output->width() == input->width() && output->height() == input->height()
 && common->strideX() == 1 && common->strideY() == 1;
 #ifdef MNN_LOW_MEMORY
 if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
 if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
-#ifdef MNN_KLEIDIAI_ENABLED
-do {
-if (!weightQuantInfo->canUseInt4) {
-break;
-}
-auto convOp = op->main_as_Convolution2D();
-auto core = static_cast<CPUBackend*>(backend)->functions();
-int oc = convOp->common()->outputCount();
-int ic = convOp->common()->inputCount();
-
-int blockNum = 1;
-int dequantCnt = weightQuantInfo->alphaSize;
-if (weightQuantInfo->asymmetric) {
-dequantCnt /= 2;
-}
-blockNum = dequantCnt / oc;
-
-bool bAsym = weightQuantInfo->asymmetric;
-size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
-
-KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
-
-KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
-if(!kai.isLoaded(accelType)) {
-kai.setLoaded(accelType);
-kai.printInfo(accelType);
-}
-
-if(!kai.canAccelerate(accelType, convOp->common())){
-break;
-}
-return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
-} while (0);
-#endif
-
 return new DenseConvInt8TiledExecutor(backend, op, weightQuantInfo, true);
 } else {
 return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
@@ -96,37 +160,16 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 }
 #else
 if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
-#ifdef MNN_KLEIDIAI_ENABLED
-if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
-return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
-}
-#endif
-
 return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
 }
 #endif

 #ifndef MNN_REDUCE_SIZE
 if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
-#ifdef MNN_KLEIDIAI_ENABLED
-auto bytes = cpuBackend->functions()->bytes;
-auto accelType = (bytes==2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
-KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
-if (kai.canAccelerate(accelType)){
-return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
-}
-#endif //MNN_KLEIDIAI_ENABLED
-
 return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize);
 }
 #endif

-#ifdef MNN_KLEIDIAI_ENABLED
-if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
-return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
-}
-#endif
-
 if (cpuBackend->getRuntime()->hint().winogradMemoryUsed == 0 || (!ConvolutionWinogradBridge::canUseWinograd(common))) {
 return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, nullptr);
 }
@@ -303,4 +303,4 @@ ErrorCode KleidiAIConvInt8::onExecute(const std::vector<Tensor*>& inputs, const
 }

 } // namespace MNN
-#endif //MNN_KLEIDIAI_ENABLED
+#endif //MNN_KLEIDIAI_ENABLED
@@ -6,10 +6,8 @@

 #ifndef KleidiAIConvInt8_hpp
 #define KleidiAIConvInt8_hpp
-#ifdef MNN_KLEIDIAI_ENABLED
 #include "backend/cpu/CPUConvolution.hpp"
 #include "Int8FunctionsOpt.h"
 #include "CommonOptFunction.h"
 #include "backend/cpu/arm/mnn_kleidiai.h"

 namespace MNN {
 class KleidiAIConvInt8 : public CPUConvolution {
@@ -31,5 +29,4 @@ private:
 };

 } // namespace MNN
-#endif // MNN_KLEIDIAI_ENABLED
-#endif /* KleidiAIConvInt8_hpp */
+#endif /* KleidiAIConvInt8_hpp */
@@ -3,19 +3,15 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //

-#ifdef MNN_KLEIDIAI_ENABLED
 #include "KleidiAIConvolution.hpp"
 #include <string.h>
 #include "core/BufferAllocator.hpp"
 #include "backend/cpu/CPUBackend.hpp"
 #include "core/Concurrency.h"
 #include "core/TensorUtils.hpp"
 #include "backend/cpu/CPUTensorConvert.hpp"

 namespace MNN {
+#ifndef MNN_REDUCE_SIZE

 KleidiAIConvolution::KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight,
 size_t originWeightSize, const float *bias, size_t biasSize)
 : CPUConvolution(common, b) {
@@ -228,7 +224,5 @@ ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, co

 return NO_ERROR;
 }

+#endif
 } // namespace MNN
-#endif //MNN_KLEIDIAI_ENABLED
+#endif //MNN_KLEIDIAI_ENABLED
@@ -6,12 +6,9 @@

 #ifndef KleidiAIConvolution_hpp
 #define KleidiAIConvolution_hpp
-#ifdef MNN_KLEIDIAI_ENABLED
 #include <functional>
 #include "backend/cpu/CPUConvolution.hpp"
 #include "backend/cpu/arm/mnn_kleidiai.h"
 namespace MNN {
+#ifndef MNN_REDUCE_SIZE

 class KleidiAIConvolution : public CPUConvolution{
 public:
 KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight, size_t originWeightSize, const float *bias, size_t biasSize);
@@ -30,8 +27,5 @@ class KleidiAIConvolution : public CPUConvolution{
 KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER;
 std::vector<float> mPostParameters;
 };
-#endif //MNN_KLEIDIAI_ENABLED

 } // namespace MNN
-#endif
+#endif /* KleidiAIConvolution_hpp */
@@ -1,4 +1,4 @@
-#if MNN_KLEIDIAI_ENABLED
+#ifdef MNN_KLEIDIAI_ENABLED
 #include "KleidiAIDenseConvolution.hpp"

 #include <numeric>
@@ -9,6 +9,7 @@
 #include "backend/cpu/CPUTensorConvert.hpp"
 #include "core/Macro.h"
 #include "core/TensorUtils.hpp"
+#include "core/Concurrency.h"
 #include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
 #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
 #include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
@@ -304,10 +305,15 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
 .dilatedWidth = mCommon->dilateX(),
 };

-mFunction.second = [=](int tid) {
+int threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
+mFunction.second = [=](int tId) {
 // Convert NC4HW4 to NHWC
 auto inputShape = input->shape(); // TODO check for NC4HW4, should be the NCHW
-CPUTensorConverter::convert(input, &mInputNHWC, core);
+// CPUTensorConverter::convert(input, &mInputNHWC, core);
+MNN_CONCURRENCY_BEGIN(tId, threadNum) {
+CPUTensorConverter::convert(input, &mInputNHWC, core, tId, threadNum);
+};
+MNN_CONCURRENCY_END();
 // Lhs packing
 if (bytes == 4) {
 int blockSize = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme();
@@ -348,7 +354,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
 }

 // Convert NHWC to NC4HW4
-CPUTensorConverter::convert(&mOutputNHWC, output, core);
+// CPUTensorConverter::convert(&mOutputNHWC, output, core);
+MNN_CONCURRENCY_BEGIN(tId, threadNum) {
+CPUTensorConverter::convert(&mOutputNHWC, output, core, tId, threadNum);
+};
+MNN_CONCURRENCY_END();
 };
 return NO_ERROR;
 }
@@ -359,4 +369,4 @@ ErrorCode KleidiAIDenseConvolutionImpl::onExecute(const std::vector<Tensor*>& in
 return NO_ERROR;
 }
 } // namespace MNN
-#endif
+#endif //MNN_KLEIDIAI_ENABLED
@@ -1,5 +1,3 @@
-#if MNN_KLEIDIAI_ENABLED
-
 #ifndef KleidiAIDenseConvolution_hpp
 #define KleidiAIDenseConvolution_hpp

@@ -242,4 +240,3 @@ private:
 } // namespace MNN

 #endif /* KleidiAIDenseConvolution_hpp */
-#endif
@@ -62,6 +62,8 @@ struct RuntimeHint {
 // whether to use Arm sme2 cores when threads>1
 bool useArmSme2Cores = true;

+bool enableKleidiAI = false;
+
 // Use CPU Ids
 std::vector<int> cpuIds;
 };
@@ -109,6 +109,9 @@ void Session::ModeGroup::setHint(Interpreter::HintMode hint, int value) {
 case Interpreter::HintMode::INIT_THREAD_NUMBER:
 runtimeHint.initThreadNumber = value;
 break;
+case Interpreter::HintMode::CPU_ENABLE_KLEIDIAI:
+runtimeHint.enableKleidiAI = value > 0 ? true : false;
+break;
 default:
 break;
 }
@@ -19,10 +19,6 @@
 #undef CONSTANT
 #endif // CONSTANT

-#ifdef MNN_KLEIDIAI_ENABLED
-#include "../backend/cpu/arm/mnn_kleidiai.h"
-#endif
-
 namespace MNN {
 struct TensorArrayAttr {
 // array size is dynamic or not
@@ -92,3 +92,7 @@ MNNForwardType getCurrentType() {
 return attr->firstType;
 }

+std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor() {
+auto attr = MNN::Express::ExecutorScope::Current()->getAttr();
+return MNN::Express::Executor::newExecutor(getCurrentType(), attr->config, attr->numThread);
+}
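The test diffs that follow all consume this helper the same way. A self-contained sketch of the idiom (header paths assumed from the MNN source tree):

```
#include "TestUtils.h"
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>

bool runIsolated() {
    // Clone the ambient executor so the test body runs under its own scope,
    // inheriting the current forward type, config, and thread count.
    auto executor = cloneCurrentExecutor();
    MNN::Express::ExecutorScope scope(executor);
    // ... test body goes here ...
    return true;
}
```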
@@ -104,6 +104,8 @@ inline float keepFP32Precision(float fp32Value) {
 }
 MNNForwardType getCurrentType();

+std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor();
+
 using ConvertFP32 = float(*)(float fp32Value);

 const static std::vector<ConvertFP32> FP32Converter = {
@@ -49,6 +49,12 @@ public:
 for (int i = 0; i < channel * height * width; ++i){
 inputData[i] = (rand() % 10) * 0.1;
 }
+
+MNN::BackendConfig config;
+config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
+config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
+std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
+ExecutorScope scope(executor);

 auto net = _createModel();
 auto x = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());
@@ -65,11 +71,7 @@ public:


 // clone model
-MNN::BackendConfig config;
-config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
-config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
-std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
-ExecutorScope scope(executor);

 std::unique_ptr<Module> tempModule(Module::clone(net.get()));

 auto xClone = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());
@@ -14,12 +14,16 @@
 #include <MNN/expr/Module.hpp>
 #include "MNNTestSuite.h"
 #include "MNN_generated.h"
+#include "TestUtils.h"
 using namespace MNN;
 using namespace MNN::Express;

 class GatherExprTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
+
 std::unique_ptr<MNN::OpT> gatherOp(new MNN::OpT);
 gatherOp->type = MNN::OpType_GatherND;
 auto parameter = _Input({2, 2}, NHWC, halide_type_of<int32_t>());
@@ -224,7 +228,8 @@ public:
 class GatherNdReComputeTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
-
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
 const int indices_data[] = {0, 0, 1, 1};
 auto params = _Const(inpudata, {2, 2}, NHWC, halide_type_of<float>());
@@ -22,6 +22,8 @@ public:
 MNN_ERROR("Currently don't test not cpu mmap\n");
 return true;
 }
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 auto x = _Input({1, 3, 224, 224}, NC4HW4, halide_type_of<float>());
 x->setName("x");
 auto y = _Conv(1.0f, 0.01f, x, {3, 16}, {5, 5});
@@ -11,6 +11,7 @@
 #include <random>
 #include "MNNTestSuite.h"
 #include "MNN_generated.h"
+#include "TestUtils.h"
 #include <MNN/expr/Expr.hpp>
 #include <MNN/expr/ExprCreator.hpp>
 #include <MNN/expr/Module.hpp>
@@ -66,6 +67,8 @@ static void _originMatMul(float* C, const float* A, const float* B, int e, int l
 class MatMulTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 int e = 5, h = 4, l = 6;
 if (true) {
 // Test MatMul
@@ -2,6 +2,7 @@
 #include <MNN/expr/ExprCreator.hpp>
 #include <MNN/expr/Module.hpp>
 #include "MNNTestSuite.h"
+#include "TestUtils.h"
 using namespace MNN;
 using namespace MNN::Express;

@@ -15,6 +16,8 @@ public:
 return summer;
 }
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 std::vector<VARP> empty;
 // Make Net
 auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>());
@@ -177,6 +177,8 @@ MNNTestSuiteRegister(ModuleTest, "expr/ModuleTest");
 class ModuleWrongInputTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 std::vector<int8_t> buffer;
 // construct
 {
@@ -244,6 +246,8 @@ MNNTestSuiteRegister(ModuleWrongInputTest, "expr/ModuleWrongInputTest");
 class RefTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 std::vector<int8_t> buffer;
 // construct
 {
@@ -318,6 +322,8 @@ public:
 }
 }
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 std::vector<int8_t> buffer;
 #ifdef MNN_REDUCE_SIZE
 return true;
@@ -1039,6 +1045,8 @@ MNNTestSuiteRegister(ConstMemoryReplaceTest, "expr/ConstMemoryReplaceTest");
 class MutlThreadConstReplaceTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 auto func = [precision](VARP y, int thread) {
 flatbuffers::FlatBufferBuilder builderOutput(1024);
 {
@@ -1499,6 +1507,8 @@ MNNTestSuiteRegister(ExecutorResetLoadModuleTest, "expr/ExecutorResetLoadModuleT
 class SequenceForwardResizeTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 // Make Model include convolution in shape compute and content compute
 auto x = _Input({1, 3, 24, 24}, NCHW, halide_type_of<float>());
 x->setName("x");
@@ -1606,6 +1616,8 @@ MNNTestSuiteRegister(SequenceForwardResizeTest, "expr/SequenceForwardResizeTest"
 class InputModuleTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 auto y = _mobileNetV1Expr(nullptr, false);
 std::unique_ptr<MNN::NetT> net(new NetT);
 Variable::save({y}, net.get());
@@ -35,6 +35,8 @@ static std::shared_ptr<Module> _createModel() {
 class RasterOutputTest : public MNNTestCase {
 public:
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 auto net = _createModel();
 auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>());
 auto y = _Transpose(x, {0, 1, 3, 2});
@@ -66,6 +66,11 @@ int main(int argc, char* argv[]) {
 dynamicOption = atoi(argv[7]);
 FUNC_PRINT(dynamicOption);
 }
+bool enableKleidiAI = false;
+if (argc > 8) {
+enableKleidiAI = atoi(argv[8]) > 0 ? true : false;
+FUNC_PRINT(enableKleidiAI);
+}
 auto exe = MNN::Express::Executor::newExecutor(type, config, thread);
 if (exe == nullptr) {
 MNN_ERROR("Can't create executor with type:%d, exit!\n", type);
@@ -76,6 +81,7 @@ int main(int argc, char* argv[]) {
 // set hint
 MNN::RuntimeHint hint;
 hint.dynamicQuantOption = dynamicOption;
+hint.enableKleidiAI = enableKleidiAI;
 scope.Current()->getRuntime().second->setRuntimeHint(hint);
 MNNTestSuite::get()->pStaus.memory = memory;
 MNNTestSuite::get()->pStaus.precision = precision;
@@ -202,6 +202,8 @@ class BinaryBroadcastTest : public MNNTestCase {
 virtual ~BinaryBroadcastTest() = default;

 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 bool resultNCHW = testDimensionFormat(NCHW, precision);
 bool resultNHWC = testDimensionFormat(NHWC, precision);

@@ -17,6 +17,8 @@ class ReverseTest : public MNNTestCase {
 public:
 virtual ~ReverseTest() = default;
 virtual bool run(int precision) {
+auto executor = cloneCurrentExecutor();
+ExecutorScope scope(executor);
 std::shared_ptr<MNN::Express::Module> net;
 {
 auto input = _Input({3, 2, 3}, NCHW);
@@ -109,7 +109,7 @@ static inline std::vector<int> parseIntList(const std::string& str, char delim)
 int main(int argc, char *argv[]) {
 if (argc < 3) {
 MNN_PRINT("=======================================================================================================================================\n");
-MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds]\n");
+MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds] [enableKleidiAI]\n");
 MNN_PRINT("=======================================================================================================================================\n");
 return 0;
 }
@@ -247,11 +247,16 @@ int main(int argc, char *argv[]) {
 for (auto id : cpuIds) {
 MNN_PRINT("%d ", id);
 }
+bool enableKleidiAI = false;
+if (argc > 10) {
+enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
+}
 MNN_PRINT("\n");
 FUNC_PRINT(precision);
 FUNC_PRINT(memory);
 FUNC_PRINT(power);
 FUNC_PRINT_ALL(cacheFileName, s);
+FUNC_PRINT(enableKleidiAI);
 // create session
 MNN::ScheduleConfig config;
 config.type = type;
@@ -320,6 +325,10 @@ int main(int argc, char *argv[]) {
 rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 2);
 }

+if (enableKleidiAI) {
+rtmgr->setHint(Interpreter::CPU_ENABLE_KLEIDIAI, true);
+}
+
 // rtmgr->setHint(Interpreter::CPU_SME2_INSTRUCTIONS, false);

 if (runMask & 2048) {
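The same hint works on any Express RuntimeManager, not just the tool shown here. A minimal sketch with illustrative configuration values:

```
#include <MNN/expr/Executor.hpp>
#include <MNN/Interpreter.hpp>
#include <memory>

int main() {
    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;
    config.numThread = 4;
    std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr(
        MNN::Express::Executor::RuntimeManager::createRuntimeManager(config));
    // Same call the tool above makes when [enableKleidiAI] is non-zero.
    rtmgr->setHint(MNN::Interpreter::CPU_ENABLE_KLEIDIAI, true);
    return 0;
}
```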
@@ -645,18 +645,21 @@ VARP Omni::gen_position_ids(int seq_len) {
 positionIds = _Input({3, seq_len}, NCHW, halide_type_of<int>());
 }
 auto ptr = positionIds->writeMap<int>();
 if (mContext->gen_seq_len > 0) {
-for (int i=0; i<seq_len; ++i) {
-auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
-ptr[i + 0] = pos;
-ptr[i + seq_len] = pos;
-ptr[i + seq_len * 2] = pos;
-}
+if (seq_len == 1) {
+ptr[0] = mContext->gen_seq_len + mPositionIds.back();
+ptr[1] = ptr[0];
+ptr[2] = ptr[0];
+}
 } else {
 for (int i = 0; i < seq_len; i++) {
-ptr[i] = mPositionIds.mT[i];
-ptr[i + seq_len] = mPositionIds.mH[i];
-ptr[i + seq_len * 2] = mPositionIds.mW[i];
+if (mPositionIds.mT.size() != seq_len) {
+ptr[i] = i;
+ptr[i + seq_len] = i;
+ptr[i + seq_len * 2] = i;
+} else {
+ptr[i] = mPositionIds.mT[i];
+ptr[i + seq_len] = mPositionIds.mH[i];
+ptr[i + seq_len * 2] = mPositionIds.mW[i];
+}
 }
 if (mTalker) {
 mTalker->setPostionIds(mPositionIds);
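As a toy illustration of the 3 x seq_len buffer gen_position_ids fills (one row each for what appear to be the temporal, height, and width components, given the mT/mH/mW names; the values below are made up):

```
#include <cstdio>
#include <vector>

int main() {
    const int seq_len = 4;
    // Hypothetical per-token components; the real ones come from mPositionIds.
    std::vector<int> t = {0, 1, 2, 3}, h = {0, 0, 1, 1}, w = {0, 1, 0, 1};
    std::vector<int> ptr(3 * seq_len);
    for (int i = 0; i < seq_len; i++) {
        ptr[i] = t[i];               // temporal row
        ptr[i + seq_len] = h[i];     // height row
        ptr[i + seq_len * 2] = w[i]; // width row
    }
    for (int v : ptr) printf("%d ", v); // prints: 0 1 2 3 0 0 1 1 0 1 0 1
    printf("\n");
    return 0;
}
```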