Merge branch 'master' into feature/rvv-opt

This commit is contained in:
ihb2032 2025-08-06 15:02:23 +08:00
commit 4cf16d6761
66 changed files with 2316 additions and 1607 deletions

View File

@ -236,7 +236,7 @@ option(MNN_OPENGL "Enable OpenGL" OFF)
option(MNN_VULKAN "Enable Vulkan" OFF)
option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF)
option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF)
option(MNN_KLEIDIAI "Enable KLEIDIAI" ON)
option(MNN_ONEDNN "Enable oneDNN" OFF)
option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON)
option(MNN_AVX512 "Enable AVX512" OFF)

View File

@ -11,19 +11,13 @@
## News 🔥
- [2025/08/05] MNN Chat Android is available on [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release)!
- [2025/06/11] New app MNN-TaoAvatar released: talk with a 3D avatar offline, with LLM, ASR, TTS, A2BS and NNR models all running locally on your device! [MNN-TaoAvatar](./apps/Android/MnnTaoAvatar/README.md)
<p align="center">
<img width="20%" alt="Icon" src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;">
</p>
- [2025/05/30] MNN Chat app supports DeepSeek-R1-0528-Qwen3, Qwen3-30B-A3B, SmolVLM and FastVLM; see [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases).
- [2025/05/12] Android app supports Qwen2.5 Omni 3B and 7B; see [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases).
<p align="center">
<img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_home_new.jpg" style="margin: 0 10px;">
<img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_sound_new.jpg" style="margin: 0 10px;">
<img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_image_new.jpg" style="margin: 0 10px;">
</p>
<details>
<summary> History News </summary>

View File

@ -3,6 +3,8 @@
[Download](#releases) [下载](./README_CN.md#releases)
[GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release)
[iOS App](../../iOS/MNNLLMChat/README.md)
## Introduction
@ -59,6 +61,14 @@ This is our full multimodal language model (LLM) Android app
```
# Releases
## Version 0.6.8
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk)
+ Add new models: SmolLM3-3B and gemma-3-1b.
+ Support penalty sampler in mixed sampler mode.
+ Switch models directly in the chat screen.
+ Update models when the remote models change.
+ Fix download source for HuggingFace.
+ Support real-time voice calls with ASR and TTS.
## Version 0.5.1.2
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk)
+ Fix HuggingFace download error

View File

@ -53,6 +53,13 @@
```
# Releases
+ 点击这里 [下载](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk)
+ 新增模型支持:现已支持 SmolLM3-3B 和 gemma-3-1b 模型。
+ 采样器功能增强:在混合采样器模式中新增对 penalty sampler 的支持,提升生成质量与多样性。
+ 模型切换优化:在聊天界面中支持实时切换模型,提升使用灵活性。
+ 模型热更新支持:当远程模型发生变化时,无需删除原有模型即可完成更新,避免重复下载。
+ 下载源优化:优化了 HuggingFace 模型的下载源,提升下载速度与稳定性。
+ 新增语音通话功能:支持实时语音对话,集成语音识别(ASR)与语音合成(TTS)功能,带来更丰富的交互体验。
## Version 0.5.1.2
+ 点击这里 [下载](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk)

View File

@ -1 +1,5 @@
/build
/build
# Native libraries (downloaded at build time)
src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.so
src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.zip

View File

@ -5,6 +5,49 @@ plugins {
}
task downloadAndUnzipNativeLibs {
group = 'Pre-build'
description = 'Downloads and unzips libsherpa-mnn-jni native library from CDN.'
def nativeLibsUrl = 'https://meta.alicdn.com/data/mnn/libs/libsherpa-mnn-jni-16k.zip'
def zipFileName = 'libsherpa-mnn-jni-16k.zip'
def outputDir = file('src/main/jniLibs/arm64-v8a')
def downloadedZip = new File(project.buildDir, zipFileName)
def checkFile = new File(outputDir, 'libsherpa-mnn-jni.so')
inputs.property('url', nativeLibsUrl)
outputs.file(checkFile)
doLast {
println "-> Executing downloadAndUnzipNativeLibs task..."
println " Downloading from ${nativeLibsUrl}"
// Create output directory if it doesn't exist
outputDir.mkdirs()
ant.get(src: nativeLibsUrl, dest: downloadedZip)
if (!downloadedZip.exists()) {
throw new GradleException("Download failed: ${downloadedZip} not found.")
}
println " Download complete."
println " Unzipping ${downloadedZip.name} to ${outputDir}..."
copy {
from(zipTree(downloadedZip))
into(outputDir)
}
println " Unzip complete."
downloadedZip.delete()
}
onlyIf {
println "-> Checking if native libs exist... [Exists: ${checkFile.exists()}]"
return !checkFile.exists()
}
}
preBuild.dependsOn downloadAndUnzipNativeLibs
android {
namespace 'com.alibaba.mnnllm.android'
compileSdk 35
@ -16,8 +59,8 @@ android {
applicationId "com.alibaba.mnnllm.android"
minSdk 26
targetSdk 35
versionCode 606
versionName "0.6.6"
versionCode 608
versionName "0.6.8"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
@ -56,18 +99,22 @@ android {
signingConfigs {
release {
storeFile file(System.getenv("KEYSTORE_FILE"))
storePassword System.getenv("KEYSTORE_PASSWORD")
keyAlias System.getenv("KEY_ALIAS")
keyPassword System.getenv("KEY_PASSWORD")
def keystoreFile = System.getenv("KEYSTORE_FILE")
if (keystoreFile) {
storeFile file(keystoreFile)
storePassword System.getenv("KEYSTORE_PASSWORD")
keyAlias System.getenv("KEY_ALIAS")
keyPassword System.getenv("KEY_PASSWORD")
}
}
}
buildTypes {
release {
signingConfig signingConfigs.release
if (System.getenv("KEYSTORE_FILE")) {
signingConfig signingConfigs.release
}
applicationIdSuffix ".release"
}
debug {
versionNameSuffix ".gp"
}
}
}

View File

@ -1,3 +1,3 @@
<resources>
<string name="app_name">MNN Chat-dev</string>
<string name="app_name">MNN Chat</string>
</resources>

View File

@ -1,3 +1,3 @@
<resources>
<string name="app_name">MNN Chat-dev</string>
<string name="app_name">MNN Chat</string>
</resources>

File diff suppressed because it is too large

View File

@ -0,0 +1,97 @@
// Created by ruoyi.sjd on 2025/1/27.
// Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved.
package com.alibaba.mls.api.download
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
/**
* Centralized coroutine management for download operations
* Provides configurable scope and dispatcher to avoid threading issues
*
* Usage:
* ```
* // Configure in Application.onCreate()
* DownloadCoroutineManager.configureDispatcher(Dispatchers.IO) // or Dispatchers.Default
*
* // In downloaders
* DownloadCoroutineManager.launchDownload {
* // download logic
* }
* ```
*/
object DownloadCoroutineManager {
private const val TAG = "DownloadCoroutineManager"
/**
* Configurable dispatcher for download operations
* Defaults to Dispatchers.Default to avoid IO dispatcher blocking issues
*/
var downloadDispatcher: CoroutineDispatcher = Dispatchers.Default
private set
/**
* Coroutine scope for download operations
*/
private val _downloadScope by lazy {
CoroutineScope(downloadDispatcher + SupervisorJob())
}
/**
* Configure the dispatcher used for download operations
* @param dispatcher The coroutine dispatcher to use
*/
fun configureDispatcher(dispatcher: CoroutineDispatcher) {
Log.d(TAG, "Configuring download dispatcher to: $dispatcher")
downloadDispatcher = dispatcher
}
/**
* Initialize with the IO dispatcher (call this when the system is stable)
*/
fun initializeWithIO() {
configureDispatcher(Dispatchers.IO)
}
/**
* Initialize with the Default dispatcher (call this when the IO dispatcher has issues)
*/
fun initializeWithDefault() {
configureDispatcher(Dispatchers.Default)
}
/**
* Launch a download coroutine with proper error handling
* @param block The suspend function to execute
*/
fun launchDownload(block: suspend CoroutineScope.() -> Unit) {
_downloadScope.launch(downloadDispatcher, block = block)
}
/**
* Get a scope with download dispatcher for manual coroutine management
*/
fun getDownloadScope(): CoroutineScope {
return _downloadScope
}
/**
* Reset to default configuration (for testing or recovery)
*/
fun resetToDefault() {
Log.d(TAG, "Resetting to default dispatcher (Dispatchers.Default)")
downloadDispatcher = Dispatchers.Default
}
/**
* Get current dispatcher info for debugging
*/
fun getCurrentDispatcherInfo(): String {
return "Current download dispatcher: $downloadDispatcher"
}
}

View File

@ -9,6 +9,7 @@ data class DownloadInfo(
var totalSize: Long = 0,
var speedInfo: String = "",
var errorMessage: String? = null,
var errorException: Exception? = null,
var lastLogTime: Long = 0,
var lastProgressUpdateTime: Long = 0,
var progressStage: String = "",

View File

@ -26,6 +26,7 @@ object DownloadPersistentData {
const val METADATA_KEY: String = "meta_data"
const val SIZE_TOTAL_KEY: String = "size_total"
const val SIZE_SAVED_KEY: String = "size_saved"
const val SIZE_MARKET_TOTAL_KEY: String = "size_market_total"
const val DOWNLOADED_TIME_KEY: String = "downloaded_time"
// Create preference keys with modelId
@ -38,6 +39,9 @@ object DownloadPersistentData {
private fun createMetaDataKey(modelId: String): Preferences.Key<String> =
stringPreferencesKey("${METADATA_KEY}_$modelId")
private fun createSizeMarketTotalKey(modelId: String): Preferences.Key<Long> =
longPreferencesKey("${SIZE_MARKET_TOTAL_KEY}_$modelId")
private fun createDownloadedTimeKey(modelId: String): Preferences.Key<Long> =
longPreferencesKey("${DOWNLOADED_TIME_KEY}_$modelId")
@ -204,6 +208,35 @@ object DownloadPersistentData {
return dataStoreValue ?: 0L
}
fun saveMarketSizeTotal(context: Context, modelId: String, total: Long) {
runBlocking { saveMarketSizeTotalSuspend(context, modelId, total) }
}
suspend fun saveMarketSizeTotalSuspend(context: Context, modelId: String, total: Long) {
val normalizedModelId = ModelUtils.safeModelId(modelId)
val key = createSizeMarketTotalKey(normalizedModelId)
context.downloadDataStore.edit { preferences ->
preferences[key] = total
}
}
fun getMarketSizeTotal(context: Context, modelId: String): Long {
return runBlocking { getMarketSizeTotalSuspend(context, modelId) }
}
suspend fun getMarketSizeTotalSuspend(context: Context, modelId: String): Long {
val normalizedModelId = ModelUtils.safeModelId(modelId)
val key = createSizeMarketTotalKey(normalizedModelId)
// Try to read from DataStore
val dataStoreValue = context.downloadDataStore.data
.map { preferences -> preferences[key] }
.first()
return dataStoreValue ?: 0L
}
// Private migration helpers
private fun deleteSharedPrefsFileIfEmpty(context: Context, preferencesName: String) {
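To make the intent of the new market-size entries concrete, here is a minimal Kotlin sketch (not part of the commit) of how they can back up the regular download total; the helper function and its call site are assumed, only the `DownloadPersistentData` calls come from this diff.

```kotlin
import android.content.Context
import com.alibaba.mls.api.download.DownloadPersistentData

// Hypothetical helper: persist the size declared in model_market.json, prefer the
// real total recorded during a download, and fall back to the market size otherwise.
suspend fun resolveDisplayTotalSize(context: Context, modelId: String, marketFileSize: Long): Long {
    if (marketFileSize > 0) {
        DownloadPersistentData.saveMarketSizeTotalSuspend(context, modelId, marketFileSize)
    }
    val recordedTotal = DownloadPersistentData.getDownloadSizeTotal(context, modelId)
    if (recordedTotal > 0) return recordedTotal
    // Returns 0 when neither a real nor a market-declared size is known yet.
    return DownloadPersistentData.getMarketSizeTotalSuspend(context, modelId)
}
```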

View File

@ -24,8 +24,6 @@ import com.alibaba.mnnllm.android.utils.CurrentActivityTracker
import com.alibaba.mnnllm.android.utils.FileUtils
import com.alibaba.mnnllm.android.utils.MmapUtils.clearMmapCache
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.util.concurrent.ConcurrentHashMap
@ -42,7 +40,7 @@ class LoggingDownloadListener : DownloadListener {
}
override fun onDownloadProgress(modelId: String, downloadInfo: DownloadInfo) {
Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo}")
Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo} stage: ${downloadInfo.progressStage}")
}
override fun onDownloadFinished(modelId: String, path: String) {
@ -253,7 +251,6 @@ class ModelDownloadManager private constructor(context: Context) {
fun getDownloadInfo(modelId: String): DownloadInfo {
Log.d(TAG, "getDownloadInfo: $modelId totalSize: ${DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelId)}" +
" progress: ${getRealDownloadSize(modelId)}")
val source = ModelUtils.getSource(modelId)!!
if (!downloadInfoMap.containsKey(modelId)) {
val downloadInfo = DownloadInfo()
if (getDownloadedFile(modelId) != null) {
@ -277,7 +274,11 @@ class ModelDownloadManager private constructor(context: Context) {
}
downloadInfoMap[modelId] = downloadInfo
if (downloadInfo.totalSize < 100) {
getRepoSize(modelId, source)
val marketSize = DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
Log.d(TAG, "getMarketSize for ${modelId} size: $marketSize")
if (marketSize > 0) {
downloadInfo.totalSize = marketSize
}
}
}
val downloadInfo = downloadInfoMap[modelId]!!
@ -305,26 +306,6 @@ class ModelDownloadManager private constructor(context: Context) {
return savedSize
}
private fun getRepoSize(
modelId: String,
source: Any,
modelName: String = modelId
) {
MainScope().launch {
val downloader = when (source) {
is String -> getDownloaderForSource(source)
is ModelSources.ModelSourceType -> getDownloaderForSource(source)
else -> throw IllegalArgumentException("Unsupported source type: ${source::class.java.name}")
}
val repoSize = downloader.getRepoSize(modelId)
if (repoSize > 0 && DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelName) <= 0L) {
DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelName, repoSize)
downloadInfoMap[modelName]?.totalSize = repoSize
listeners.forEach { it.onDownloadTotalSize(modelName, repoSize) }
}
}
}
private fun setDownloadFinished(modelId: String, path: String) {
val downloadInfo = downloadInfoMap[modelId] ?: return
downloadInfo.downloadState = DownloadState.COMPLETED
@ -422,6 +403,7 @@ class ModelDownloadManager private constructor(context: Context) {
val info = downloadInfoMap.getOrPut(modelId) { DownloadInfo() }
info.downloadState = DownloadState.FAILED
info.errorMessage = e.message
info.errorException = e
listeners.forEach {
Log.d(TAG, "[setDownloadFailed] Notifying listener: ${it.javaClass.simpleName}")
it.onDownloadFailed(modelId, e)
@ -470,9 +452,6 @@ class ModelDownloadManager private constructor(context: Context) {
savedSize: Long,
totalSize: Long
) {
if (totalSize <= 0) {
return
}
val downloadInfo = downloadInfoMap.getOrPut(modelId) { DownloadInfo() }
val currentTime = System.currentTimeMillis()
if (stage == downloadInfo.progressStage && currentTime - downloadInfo.lastProgressUpdateTime < 1000) {
@ -486,11 +465,13 @@ class ModelDownloadManager private constructor(context: Context) {
if (downloadInfo.savedSize <= 0) {
downloadInfo.savedSize = savedSize
}
downloadInfo.progress = savedSize.toDouble() / totalSize
DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize)
DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize)
calculateDownloadSpeed(downloadInfo, savedSize)
Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId")
if (totalSize > 0) {
downloadInfo.progress = savedSize.toDouble() / totalSize
DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize)
DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize)
calculateDownloadSpeed(downloadInfo, savedSize)
}
Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId stage: $stage")
listeners.forEach { it.onDownloadProgress(modelId, downloadInfo) }
}

View File

@ -25,8 +25,7 @@ import com.alibaba.mls.api.hf.HfFileMetadata
import com.alibaba.mls.api.hf.HfRepoInfo
import com.alibaba.mnnllm.android.chat.model.ChatDataManager
import com.alibaba.mnnllm.android.utils.TimeUtils
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
import java.io.File
@ -60,10 +59,11 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
* Unified method to fetch repo information from HuggingFace API
* @param modelId the model ID to fetch info for
* @param calculateSize whether to calculate repo size (requires additional network requests)
* @return HfRepoInfo object or null if failed
* @return HfRepoInfo object
* @throws FileDownloadException if failed to fetch repo info
*/
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo? {
return withContext(Dispatchers.IO) {
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val hfModelId = hfModelId(modelId)
val response = getHfApiClient().apiService.getRepoInfo(hfModelId, "main")?.execute()
@ -79,11 +79,16 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
callback?.onRepoInfo(modelId, lastModified, repoSize)
repoInfo
} else {
null
val errorMsg = if (response?.isSuccessful == false) {
"API request failed with code ${response.code()}: ${response.message()}"
} else {
"API response was null or empty"
}
throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
}
}.getOrElse { exception ->
Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
null
throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
}
}
}
@ -92,7 +97,7 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
* Calculate total size of all files in the repo
*/
private suspend fun calculateRepoSize(hfRepoInfo: HfRepoInfo): Long {
return withContext(Dispatchers.IO) {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val metaList = requestMetaDataList(hfRepoInfo)
metaList.sumOf { it?.size ?: 0L }
@ -101,20 +106,22 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
}
override fun download(modelId: String) {
executor!!.submit {
runBlocking {
DownloadCoroutineManager.launchDownload {
try {
val repoInfo = fetchRepoInfo(modelId)
if (repoInfo != null) {
downloadHfRepo(repoInfo)
} else {
callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
}
downloadHfRepo(repoInfo)
} catch (e: FileDownloadException) {
callback?.onDownloadFailed(modelId, e)
}
}
}
override suspend fun checkUpdate(modelId: String) {
fetchRepoInfo(modelId)
try {
fetchRepoInfo(modelId)
} catch (e: FileDownloadException) {
Log.e(TAG, "Failed to check update for $modelId", e)
}
}
fun downloadHfRepo(hfRepoInfo: HfRepoInfo) {
@ -127,17 +134,23 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?,
}
override suspend fun getRepoSize(modelId: String): Long {
return withContext(Dispatchers.IO) {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
if (repoInfo != null) {
// Size was already calculated in fetchRepoInfo, but we can also calculate it directly
calculateRepoSize(repoInfo)
} else {
0L
}
// Size was already calculated in fetchRepoInfo, but we can also calculate it directly
calculateRepoSize(repoInfo)
}.getOrElse { exception ->
Log.e(TAG, "Failed to get repo size for $modelId", exception)
// Try to get file_size from saved market data as fallback
try {
val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
if (marketSize > 0) {
Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
return@withContext marketSize
}
} catch (e: Exception) {
Log.w(TAG, "Failed to get saved market size for $modelId", e)
}
0L
}
}

View File

@ -22,8 +22,7 @@ import com.alibaba.mls.api.ml.FileInfo
import com.alibaba.mls.api.ml.MlApiClient
import com.alibaba.mls.api.ml.MlRepoInfo
import com.alibaba.mnnllm.android.model.ModelUtils
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import java.io.File
import com.alibaba.mls.api.ml.MlRepoData
@ -43,15 +42,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
* Unified method to fetch repo information from Modelers API
* @param modelId the model ID to fetch info for
* @param calculateSize whether to calculate repo size (requires additional network requests)
* @return MlRepoInfo object or null if failed
* @return MlRepoInfo object
* @throws FileDownloadException if failed to fetch repo info
*/
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo? {
return withContext(Dispatchers.IO) {
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val modelersId = ModelUtils.getRepositoryPath(modelId)
val split = modelersId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
if (split.size != 2) {
return@runCatching null
throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo")
}
val response = mlApiClient.apiService.getModelFiles(split[0], split[1], "").execute()
@ -84,24 +84,33 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
callback?.onRepoInfo(modelId, lastModified, repoSize)
repoInfo
} else {
null
val errorMsg = if (!response.isSuccessful) {
"API request failed with code ${response.code()}: ${response.message()}"
} else {
"API response was null or empty"
}
throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
}
}.getOrElse { exception ->
Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
null
throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
}
}
}
override fun download(modelId: String) {
Log.d(TAG, "start download ${modelId}")
DownloadExecutor.executeScope.launch {
DownloadCoroutineManager.launchDownload {
downloadRepo(modelId)
}
}
override suspend fun checkUpdate(modelId: String) {
fetchRepoInfo(modelId, false)
try {
fetchRepoInfo(modelId, false)
} catch (e: FileDownloadException) {
Log.e(TAG, "Failed to check update for $modelId", e)
}
}
override fun getDownloadPath(modelId: String): File {
@ -125,16 +134,22 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
}
override suspend fun getRepoSize(modelId: String): Long {
return withContext(Dispatchers.IO) {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
if (repoInfo != null) {
repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size }
} else {
0L
}
repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size }
}.getOrElse { exception ->
Log.e(TAG, "Failed to get repo size for $modelId", exception)
// Try to get file_size from saved market data as fallback
try {
val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
if (marketSize > 0) {
Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
return@withContext marketSize
}
} catch (e: Exception) {
Log.w(TAG, "Failed to get saved market size for $modelId", e)
}
0L
}
}
@ -149,15 +164,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?,
return null
}
val modelInfo = fetchRepoInfo(modelId, calculateSize = true)
if (modelInfo != null) {
try {
val modelInfo = fetchRepoInfo(modelId, calculateSize = true)
callback?.onDownloadTaskAdded()
downloadMlRepoInner(modelId, modelersId, modelInfo)
callback?.onDownloadTaskRemoved()
} else {
callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
return modelInfo
} catch (e: FileDownloadException) {
callback?.onDownloadFailed(modelId, e)
return null
}
return modelInfo
}
private fun getAllFiles(owner: String, repo: String, path: String, allFiles: MutableList<FileInfo>) {

View File

@ -6,7 +6,6 @@ import android.util.Log
import com.alibaba.mls.api.ApplicationProvider
import com.alibaba.mls.api.FileDownloadException
import com.alibaba.mls.api.hf.HfFileMetadata
import com.alibaba.mls.api.download.DownloadExecutor.Companion.executor
import com.alibaba.mls.api.download.DownloadFileUtils.createSymlink
import com.alibaba.mls.api.download.DownloadFileUtils.deleteDirectoryRecursively
import com.alibaba.mls.api.download.DownloadFileUtils.getLastFileName
@ -22,9 +21,10 @@ import com.alibaba.mls.api.ms.MsApiClient
import com.alibaba.mls.api.ms.MsRepoInfo
import com.alibaba.mnnllm.android.chat.model.ChatDataManager
import com.alibaba.mnnllm.android.model.ModelUtils
import com.alibaba.mnnllm.android.modelmarket.ModelMarketItem
import com.alibaba.mnnllm.android.modelmarket.ModelRepository
import com.alibaba.mnnllm.android.utils.TimeUtils
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import com.alibaba.mls.api.download.DownloadCoroutineManager
import kotlinx.coroutines.withContext
import retrofit2.Call
import retrofit2.Callback
@ -44,15 +44,18 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
* Unified method to fetch repo information from ModelScope API
* @param modelId the model ID to fetch info for
* @param calculateSize whether to calculate repo size (not needed for ModelScope as size is included)
* @return MsRepoInfo object or null if failed
* @return MsRepoInfo object
* @throws FileDownloadException if failed to fetch repo info
*/
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo? {
return withContext(Dispatchers.IO) {
private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo {
Log.d(TAG, "fetchRepoInfo called for modelId: $modelId, calculateSize: $calculateSize")
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val msModelId = ModelUtils.getRepositoryPath(modelId)
val split = msModelId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
if (split.size != 2) {
return@runCatching null
throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo")
}
val response = msApiClient.apiService.getModelFiles(split[0], split[1]).execute()
@ -65,25 +68,34 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
callback?.onRepoInfo(modelId, lastModified, repoSize)
repoInfo
} else {
null
val errorMsg = if (!response.isSuccessful) {
"API request failed with code ${response.code()}: ${response.message()}"
} else {
"API response was null or empty"
}
throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg")
}
}.getOrElse { exception ->
Log.e(TAG, "Failed to fetch repo info for $modelId", exception)
null
throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}")
}
}
}
override fun download(modelId: String) {
executor!!.submit {
kotlinx.coroutines.runBlocking {
downloadMsRepo(modelId)
}
Log.d(TAG, "MsModelDownloader download: $modelId")
DownloadCoroutineManager.launchDownload {
downloadMsRepo(modelId)
}
}
override suspend fun checkUpdate(modelId: String) {
fetchRepoInfo(modelId)
try {
fetchRepoInfo(modelId)
} catch (e: FileDownloadException) {
Log.e(TAG, "Failed to check update for $modelId", e)
}
}
override fun getDownloadPath(modelId: String): File {
@ -107,16 +119,22 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
}
override suspend fun getRepoSize(modelId: String): Long {
return withContext(Dispatchers.IO) {
return withContext(DownloadCoroutineManager.downloadDispatcher) {
runCatching {
val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
if (repoInfo != null) {
repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L
} else {
0L
}
repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L
}.getOrElse { exception ->
Log.e(TAG, "Failed to get repo size for $modelId", exception)
// Try to get file_size from saved market data as fallback
try {
val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId)
if (marketSize > 0) {
Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize")
return@withContext marketSize
}
} catch (e: Exception) {
Log.w(TAG, "Failed to get saved market size for $modelId", e)
}
0L
}
}
@ -124,20 +142,20 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?,
private suspend fun downloadMsRepo(modelId: String) {
val modelScopeId = ModelUtils.getRepositoryPath(modelId)
Log.d(TAG, "MsModelDownloader downloadMsRepo: $modelId modelScopeId : $modelScopeId")
val split = modelScopeId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
if (split.size != 2) {
callback?.onDownloadFailed(modelId, FileDownloadException("getRepoInfoFailed modelId format error: $modelId"))
return
}
val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
if (repoInfo != null) {
Log.d(TAG, "downloadMsRepoInner executor")
try {
val repoInfo = fetchRepoInfo(modelId, calculateSize = true)
Log.d(TAG, "downloadMsRepo repoInfo: $repoInfo")
callback?.onDownloadTaskAdded()
downloadMsRepoInner(modelId, modelScopeId, repoInfo)
callback?.onDownloadTaskRemoved()
} else {
callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId"))
} catch (e: FileDownloadException) {
callback?.onDownloadFailed(modelId, e)
}
}

View File

@ -69,7 +69,7 @@ class BenchmarkStateMachine {
fun isValidTransition(from: BenchmarkState, to: BenchmarkState): Boolean {
val validTransitions = when (from) {
BenchmarkState.IDLE -> listOf(BenchmarkState.LOADING_MODELS)
BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND)
BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND, BenchmarkState.ERROR)
BenchmarkState.READY -> listOf(BenchmarkState.INITIALIZING, BenchmarkState.LOADING_MODELS)
BenchmarkState.INITIALIZING -> listOf(BenchmarkState.RUNNING, BenchmarkState.ERROR)
BenchmarkState.RUNNING -> listOf(BenchmarkState.STOPPING, BenchmarkState.COMPLETED, BenchmarkState.ERROR)
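The extra `BenchmarkState.ERROR` target means a failure while loading the model list can now move the state machine straight into the error state; a small hypothetical check (the helper function is an illustration, not part of the commit):

```kotlin
fun verifyLoadingErrorTransition(machine: BenchmarkStateMachine) {
    // Newly permitted by this change: model-list loading failures go directly to ERROR.
    check(machine.isValidTransition(BenchmarkState.LOADING_MODELS, BenchmarkState.ERROR))
}
```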

View File

@ -356,6 +356,9 @@ class MarketItemHolder(
R.id.menu_open_model_card -> {
openModelCard(itemView.context, modelMarketItem)
}
R.id.menu_copy_error_info -> {
copyErrorInfoToClipboard(modelMarketItemWrapper)
}
}
true
}
@ -388,6 +391,10 @@ class MarketItemHolder(
// Model card: always visible
menu.findItem(R.id.menu_open_model_card).isVisible = false
// Copy error info: visible only for failed downloads
menu.findItem(R.id.menu_copy_error_info).isVisible =
downloadState == DownloadState.FAILED
}
private fun handleSettingsMenu(modelMarketItem: ModelMarketItem) {
@ -409,6 +416,41 @@ class MarketItemHolder(
private fun openModelCard(context: android.content.Context, modelMarketItem: ModelMarketItem) {
}
private fun copyErrorInfoToClipboard(modelMarketItemWrapper: ModelMarketItemWrapper) {
val context = itemView.context
val downloadInfo = modelMarketItemWrapper.downloadInfo
val modelMarketItem = modelMarketItemWrapper.modelMarketItem
if (downloadInfo.downloadState != DownloadState.FAILED) {
return
}
val errorMessage = downloadInfo.errorMessage ?: "Unknown error"
val modelInfo = "Model: ${modelMarketItem.modelName} (${modelMarketItem.modelId})"
// Build error info with stack trace if available
val errorInfoBuilder = StringBuilder().apply {
appendLine(modelInfo)
appendLine("Error: $errorMessage")
appendLine("Time: ${java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", java.util.Locale.getDefault()).format(java.util.Date())}")
// Add stack trace if exception is available
downloadInfo.errorException?.let { exception ->
appendLine()
appendLine("Stack Trace:")
appendLine(android.util.Log.getStackTraceString(exception))
}
}
val errorInfo = errorInfoBuilder.toString()
val clipboard = context.getSystemService(android.content.Context.CLIPBOARD_SERVICE) as android.content.ClipboardManager
val clip = android.content.ClipData.newPlainText("Error Info", errorInfo)
clipboard.setPrimaryClip(clip)
Toast.makeText(context, context.getString(R.string.error_info_copied), Toast.LENGTH_SHORT).show()
}
companion object {
const val TAG = "ModelItemHolder"
}

View File

@ -11,6 +11,7 @@ data class ModelMarketItem(
val categories: List<String>,
val sources: Map<String, String>,
val description: String? = null,
@SerializedName("file_size") val fileSize: Long = 0L, // File size in bytes from model_market.json
var currentSource: String = "", // e.g. "modelscope", "huggingface"
var currentRepoPath: String = "", // e.g. "MNN/Qwen-1.8B-Chat-Int4"
var modelId: String = "" // e.g. "ModelScope/MNN/Qwen-1.8B-Chat-Int4"
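For illustration, a sketch of how the new `file_size` field is picked up from model_market.json through Gson; the JSON excerpt and helper are assumptions (other fields of the entry are omitted), only the `@SerializedName("file_size")`/`fileSize` mapping comes from this diff.

```kotlin
import com.google.gson.Gson
import com.alibaba.mnnllm.android.modelmarket.ModelMarketItem

fun parseMarketEntryExample(): ModelMarketItem {
    // Hypothetical excerpt of a model_market.json entry carrying the new field.
    val entryJson = """
    {
      "modelName": "Qwen-1.8B-Chat-Int4",
      "categories": ["LLM"],
      "sources": {"HuggingFace": "taobao-mnn/Qwen-1.8B-Chat-Int4"},
      "file_size": 1234567890
    }
    """.trimIndent()

    // @SerializedName("file_size") maps the snake_case JSON key onto fileSize (bytes).
    val item = Gson().fromJson(entryJson, ModelMarketItem::class.java)
    check(item.fileSize == 1_234_567_890L)
    return item
}
```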

View File

@ -3,6 +3,7 @@ package com.alibaba.mnnllm.android.modelmarket
import android.content.Context
import android.util.Log
import androidx.preference.PreferenceManager
import com.alibaba.mls.api.download.DownloadPersistentData
import com.alibaba.mnnllm.android.mainsettings.MainSettings
import com.google.gson.Gson
import kotlinx.coroutines.Dispatchers
@ -233,6 +234,16 @@ class ModelRepository(private val context: Context) {
item.currentSource = selectedSource
item.currentRepoPath = item.sources[selectedSource]!!
item.modelId = "$selectedSource/${item.sources[selectedSource]!!}"
// Save file_size to DownloadPersistentData if available
if (item.fileSize > 0) {
try {
DownloadPersistentData.saveMarketSizeTotalSuspend(context, item.modelId, item.fileSize)
Log.d(TAG, "Saved market size for ${item.modelId}: ${item.fileSize}")
} catch (e: Exception) {
Log.w(TAG, "Failed to save market size for ${item.modelId}", e)
}
}
}
}

View File

@ -19,4 +19,7 @@
<item
android:id="@+id/menu_open_model_card"
android:title="@string/open_model_card" />
<item
android:id="@+id/menu_copy_error_info"
android:title="@string/menu_copy_error_info" />
</menu>

View File

@ -485,4 +485,6 @@
<string name="allow_network_market_data">允许联网获取模型市场数据</string>
<string name="enable_network_delay">启用网络延时3秒</string>
<string name="update_available">有新版本</string>
<string name="menu_copy_error_info">复制错误信息</string>
<string name="error_info_copied">错误信息已复制到剪贴板</string>
</resources>

View File

@ -488,4 +488,6 @@
<string name="allow_network_market_data">Allow network to fetch model market data</string>
<string name="enable_network_delay">Enable network delay (3 seconds)</string>
<string name="update_available">Update Available</string>
<string name="menu_copy_error_info">Copy Error Info</string>
<string name="error_info_copied">Error information copied to clipboard</string>
</resources>

View File

@ -0,0 +1,126 @@
import json
import argparse
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError
def get_repo_size_in_bytes(repo_id: str) -> int:
"""
Calculates the total size of all files in a Hugging Face model repository.
Args:
repo_id (str): The ID of the model repository, e.g., 'google-bert/bert-base-uncased'.
Returns:
int: The total size of the files in bytes. Returns -1 if the repository
cannot be found or an error occurs.
"""
# Initialize the HfApi client
api = HfApi()
total_size = 0
try:
# Fetch model information, including metadata for each file
print(f"Fetching metadata for repository: '{repo_id}'...")
model_info = api.model_info(repo_id=repo_id, files_metadata=True)
# Sum the size of each file in the repository
for file in model_info.siblings:
if file.size is not None:
total_size += file.size
print(f"Successfully calculated size for '{repo_id}': {total_size} bytes")
return total_size
except HfHubHTTPError as e:
# Handle cases where the repository is not found or other HTTP errors
print(f"Error: Could not retrieve info for repository '{repo_id}'. It might not exist or be private.")
print(f"Details: {e}")
return -1
except Exception as e:
# Handle other potential exceptions
print(f"An unexpected error occurred while processing '{repo_id}': {e}")
return -1
def process_model_list(models: list):
"""
Iterates through a list of model objects and adds the 'file_size' field.
Args:
models (list): A list of dictionaries, where each dictionary represents a model.
"""
if not models:
return # Do nothing if the list is empty
for model in models:
# Check if the model has a HuggingFace source
if 'sources' in model and 'HuggingFace' in model['sources']:
repo_id = model['sources']['HuggingFace']
if repo_id:
# Get the repository size and add it to the model object
file_size = get_repo_size_in_bytes(repo_id)
model['file_size'] = file_size
else:
# Handle cases where the HuggingFace repo_id is empty
model['file_size'] = -1
print(f"Warning: Empty HuggingFace repo_id for model '{model.get('modelName', 'N/A')}'.")
else:
# If no HuggingFace source, you can decide what to do.
# Here we skip adding the field, or you could add 'file_size': 0 or -1
print(f"Skipping model '{model.get('modelName', 'N/A')}' as it has no HuggingFace source.")
def main():
"""
Main function to parse arguments, read the input JSON, process it,
and write to the output JSON file.
"""
# Set up argument parser for command-line interface
parser = argparse.ArgumentParser(
description="Process a market config JSON file to add Hugging Face model sizes."
)
parser.add_argument(
"-i", "--input",
required=True,
help="Path to the input JSON file."
)
parser.add_argument(
"-o", "--output",
required=True,
help="Path to the output JSON file."
)
args = parser.parse_args()
# Read the input JSON file
try:
with open(args.input, 'r', encoding='utf-8') as f:
data = json.load(f)
except FileNotFoundError:
print(f"Error: Input file not found at '{args.input}'")
return
except json.JSONDecodeError:
print(f"Error: Could not decode JSON from the input file '{args.input}'.")
return
# Define the keys that contain lists of models to process
model_list_keys = ['models', 'tts_models', 'asr_models']
# Process each list of models
for key in model_list_keys:
if key in data and isinstance(data[key], list):
print(f"\n--- Processing models in '{key}' ---")
process_model_list(data[key])
else:
print(f"\n--- No models found in '{key}', skipping ---")
# Write the updated data to the output JSON file
try:
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
print(f"\nProcessing complete. Output successfully written to '{args.output}'")
except IOError as e:
print(f"Error: Could not write to output file '{args.output}'. Details: {e}")
if __name__ == '__main__':
main()
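Assuming the script is saved as, say, `compute_model_sizes.py` (its filename is not part of this diff), a typical invocation would be `python compute_model_sizes.py -i model_market.json -o model_market_with_sizes.json`, which writes a copy of the config with a `file_size` field added to every HuggingFace-sourced model entry.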

View File

@ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol {
// PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") {
var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1]
if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true,
if self?.modelInfo.tags.contains("Think") == true,
let text = self?.processor.process(progress: message.text) {
updateLastMsg?.text = text
} else {

View File

@ -15,7 +15,7 @@ struct ChatHistoryItemView: View {
if let lastMessage = getLastNonEmptyMessage() {
Text(String(lastMessage.content.prefix(200)))
.lineLimit(1)
.lineLimit(3)
.font(.system(size: 15, weight: .medium))
.foregroundColor(.primary)
}

View File

@ -709,9 +709,13 @@ bool remove_directory_safely(const std::string& path) {
}
std::string inputStr = [input UTF8String];
#ifdef DEBUG
if (inputStr == "benchmark") {
[blockSelf performBenchmarkWithOutput:&os];
} else {
#else
{
#endif
// Get initial context state for performance measurement
auto context = blockSelf->_llm->getContext();
int initial_prompt_len = context->prompt_len;

View File

@ -1308,7 +1308,14 @@
}
},
"Yes" : {
"localizations" : {
"zh-Hans" : {
"stringUnit" : {
"state" : "translated",
"value" : "是"
}
}
}
},
"搜索本地模型..." : {
"localizations" : {

View File

@ -125,6 +125,7 @@ class BenchmarkService: ObservableObject {
/// Releases the current model and frees associated resources
func releaseModel() {
llmEngine?.cancelInference()
llmEngine = nil
currentModelId = nil
}

View File

@ -34,6 +34,7 @@ class BenchmarkViewModel: ObservableObject {
@Published var startButtonText = String(localized: "Start Test")
@Published var isStartButtonEnabled = true
@Published var showStopConfirmation = false
// MARK: - Private Properties
@ -105,6 +106,7 @@ class BenchmarkViewModel: ObservableObject {
/// Handles benchmark stop confirmation
func onStopBenchmarkTapped() {
showStopConfirmation = false
stopBenchmark()
}
@ -119,8 +121,8 @@ class BenchmarkViewModel: ObservableObject {
showResults = false
hideStatus()
// Release model to free memory
benchmarkService.releaseModel()
// Clean up resources when deleting results
cleanupBenchmarkResources()
}
/// Placeholder for future result submission functionality
@ -200,14 +202,16 @@ class BenchmarkViewModel: ObservableObject {
private func stopBenchmark() {
updateStatus("Stopping benchmark...")
benchmarkService.stopBenchmark()
MemoryMonitor.shared.stop()
cleanupBenchmarkResources()
}
// MARK: - UI State Management
/// Updates UI state when benchmark starts
private func onBenchmarkStarted() {
isRunning = true
isStartButtonEnabled = true
startButtonText = String(localized: "Stop Test")
showProgressBar = true
showResults = false
updateStatus("Initializing benchmark...")
@ -215,11 +219,22 @@ class BenchmarkViewModel: ObservableObject {
/// Resets UI to initial state
private func resetUIState() {
isRunning = false
isStartButtonEnabled = true
startButtonText = String(localized: "Start Test")
showProgressBar = false
hideStatus()
showResults = false
cleanupBenchmarkResources()
}
/// Cleans up benchmark resources including memory monitoring and model
private func cleanupBenchmarkResources() {
MemoryMonitor.shared.stop()
MemoryMonitor.shared.reset()
// Release model to free memory
benchmarkService.releaseModel()
}
/// Updates status message display
@ -238,9 +253,9 @@ class BenchmarkViewModel: ObservableObject {
showError = true
}
/// Placeholder for stop confirmation alert (handled in View)
/// Shows stop confirmation alert
private func showStopConfirmationAlert() {
// This will be handled in the View with an alert
showStopConfirmation = true
}
/// Formats progress messages with appropriate status text based on progress type
@ -309,11 +324,14 @@ extension BenchmarkViewModel: BenchmarkCallback {
benchmarkResults = results
showResults = true
// Only stop memory monitoring if benchmark is no longer running (all tests completed)
if !isRunning {
// Stop memory monitoring
MemoryMonitor.shared.stop()
}
// Update UI state to reflect completion
isRunning = false
isStartButtonEnabled = true
startButtonText = String(localized: "Start Test")
showProgressBar = false
// Clean up resources after benchmark completion
cleanupBenchmarkResources()
// Always hide status after processing results
hideStatus()
@ -324,9 +342,16 @@ extension BenchmarkViewModel: BenchmarkCallback {
/// Handles benchmark errors with user-friendly error messages
func onBenchmarkError(_ errorCode: Int, _ message: String) {
let errorCodeName = BenchmarkErrorCode(rawValue: errorCode)?.description ?? "Unknown"
showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
// Check if this is a user-initiated stop - don't show error dialog
if errorCode == BenchmarkErrorCode.benchmarkStopped.rawValue {
print("BenchmarkViewModel: Benchmark stopped by user (\(errorCode)): \(message)")
} else {
showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)")
print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
}
resetUIState()
print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)")
}
}

View File

@ -13,7 +13,6 @@ import SwiftUI
*/
struct ModelSelectionCard: View {
@ObservedObject var viewModel: BenchmarkViewModel
@Binding var showStopConfirmation: Bool
var body: some View {
VStack(alignment: .leading, spacing: 16) {
@ -140,11 +139,7 @@ struct ModelSelectionCard: View {
private var startStopButton: some View {
Button(action: {
if viewModel.startButtonText.contains("Stop") {
showStopConfirmation = true
} else {
viewModel.onStartBenchmarkTapped()
}
viewModel.onStartBenchmarkTapped()
}) {
HStack(spacing: 12) {
ZStack {
@ -152,12 +147,12 @@ struct ModelSelectionCard: View {
.fill(Color.white.opacity(0.2))
.frame(width: 32, height: 32)
if viewModel.isRunning && viewModel.startButtonText.contains("Stop") {
if viewModel.isRunning {
ProgressView()
.progressViewStyle(CircularProgressViewStyle(tint: .white))
.scaleEffect(0.7)
} else {
Image(systemName: viewModel.startButtonText.contains("Stop") ? "stop.fill" : "play.fill")
Image(systemName: viewModel.isRunning ? "stop.fill" : "play.fill")
.font(.system(size: 16, weight: .bold))
.foregroundColor(.white)
}
@ -169,7 +164,7 @@ struct ModelSelectionCard: View {
Spacer()
if !viewModel.startButtonText.contains("Stop") {
if !viewModel.isRunning {
Image(systemName: "arrow.right")
.font(.system(size: 16, weight: .semibold))
.foregroundColor(.white.opacity(0.8))
@ -182,7 +177,7 @@ struct ModelSelectionCard: View {
RoundedRectangle(cornerRadius: 16)
.fill(
viewModel.isStartButtonEnabled ?
(viewModel.startButtonText.contains("Stop") ?
(viewModel.isRunning ?
LinearGradient(
colors: [Color.benchmarkError, Color.benchmarkError.opacity(0.8)],
startPoint: .leading,
@ -234,8 +229,7 @@ struct ModelSelectionCard: View {
#Preview {
ModelSelectionCard(
viewModel: BenchmarkViewModel(),
showStopConfirmation: .constant(false)
viewModel: BenchmarkViewModel()
)
.padding()
}

View File

@ -13,7 +13,6 @@ import SwiftUI
*/
struct BenchmarkView: View {
@StateObject private var viewModel = BenchmarkViewModel()
@State private var showStopConfirmation = false
var body: some View {
ZStack {
@ -21,8 +20,7 @@ struct BenchmarkView: View {
VStack(spacing: 24) {
// Model Selection Section
ModelSelectionCard(
viewModel: viewModel,
showStopConfirmation: $showStopConfirmation
viewModel: viewModel
)
// Progress Section
@ -55,7 +53,7 @@ struct BenchmarkView: View {
.padding(.vertical, 16)
}
}
.alert("Stop Benchmark", isPresented: $showStopConfirmation) {
.alert("Stop Benchmark", isPresented: $viewModel.showStopConfirmation) {
Button("Yes", role: .destructive) {
viewModel.onStopBenchmarkTapped()
}
@ -68,11 +66,7 @@ struct BenchmarkView: View {
} message: {
Text(viewModel.errorMessage)
}
.onReceive(viewModel.$isRunning) { isRunning in
if isRunning && viewModel.startButtonText.contains("Stop") {
showStopConfirmation = false
}
}
}
}

View File

@ -32,7 +32,9 @@ struct LocalModelListView: View {
viewModel.selectModel(model)
}) {
LocalModelRowView(model: model)
.contentShape(Rectangle())
}
.buttonStyle(PlainButtonStyle())
.listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear)
.swipeActions(edge: .trailing, allowsFullSwipe: false) {
SwipeActionsView(model: model, viewModel: viewModel)

View File

@ -59,6 +59,10 @@ struct LocalModelRowView: View {
}
}
}
Spacer()
}
.padding(.vertical, 8)
.frame(maxWidth: .infinity, alignment: .leading)
}
}

View File

@ -142,7 +142,10 @@ struct MainTabView: View {
.edgesIgnoringSafeArea(.all)
}
.onChange(of: showHistory) { oldValue, newValue in
if !newValue {
if newValue {
// Refresh chat history when opening the side menu
histories = ChatHistoryManager.shared.getAllHistory()
} else {
DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) {
withAnimation {
showHistoryButton = true
@ -191,6 +194,9 @@ struct MainTabView: View {
modelListViewModel.recordModelUsage(modelName: model.modelName)
}
// Refresh chat history when returning from chat
histories = ChatHistoryManager.shared.getAllHistory()
// Clear selections
modelListViewModel.selectedModel = nil
selectedHistory = nil

View File

@ -67,7 +67,14 @@ struct ModelInfo: Codable {
}
let sourceKey = ModelSourceManager.shared.selectedSource.rawValue
return sources[sourceKey] ?? "taobao-mnn/\(modelName)"
let baseId = sources[sourceKey] ?? "taobao-mnn/\(modelName)"
// Add vendor prefix to ensure uniqueness for local models
if vendor == "Local" {
return "local/\(modelName)"
}
return baseId
}
var localizedTags: [String] {

View File

@ -72,7 +72,7 @@ class ModelListViewModel: ObservableObject {
if !foundModelFiles.isEmpty {
// Check if we have a complete model (at least config.json)
if foundModelFiles.contains("config.json") {
let modelName = "Qwen3-0.6B-MNN"
let modelName = "Qwen3-0.6B-MNN-Inside"
let localModel = ModelInfo(
modelName: modelName,
tags: [NSLocalizedString("tag.deepThinking", comment: "Deep thinking tag for local model"),
@ -143,8 +143,10 @@ class ModelListViewModel: ObservableObject {
let itemConfigPath = (itemPath as NSString).appendingPathComponent("config.json")
if fileManager.fileExists(atPath: itemConfigPath) {
// Use custom name for Qwen3-0.6B-MNN to avoid conflicts
let modelName = item == "Qwen3-0.6B-MNN" ? "Qwen3-0.6B-MNN-Inside" : item
let localModel = ModelInfo(
modelName: item,
modelName: modelName,
tags: ["local", "bundled"],
categories: ["Local Models"],
vendor: "Local",
@ -153,7 +155,7 @@ class ModelListViewModel: ObservableObject {
)
localModels.append(localModel)
ModelStorageManager.shared.markModelAsDownloaded(item)
ModelStorageManager.shared.markModelAsDownloaded(modelName)
}
}
}
@ -175,12 +177,15 @@ class ModelListViewModel: ObservableObject {
var fetchedModels = info.models
// Add LocalModel folder models
// Add LocalModel folder models, avoiding duplicates
let localModels = await loadLocalModels()
fetchedModels.append(contentsOf: localModels)
let existingModelNames = Set(fetchedModels.map { $0.modelName })
let uniqueLocalModels = localModels.filter { !existingModelNames.contains($0.modelName) }
fetchedModels.append(contentsOf: uniqueLocalModels)
filterDiffusionModels(fetchedModels: &fetchedModels)
loadCachedSizes(for: &fetchedModels)
syncDownloadStatus(for: &fetchedModels)
sortModels(fetchedModels: &fetchedModels)
self.models = fetchedModels
@ -203,6 +208,18 @@ class ModelListViewModel: ObservableObject {
}
}
private func syncDownloadStatus(for models: inout [ModelInfo]) {
for i in 0..<models.count {
let isDownloaded = ModelStorageManager.shared.isModelDownloaded(models[i].modelName)
models[i].isDownloaded = isDownloaded
// Also sync last used date
if let lastUsed = ModelStorageManager.shared.getLastUsed(for: models[i].modelName) {
models[i].lastUsedAt = lastUsed
}
}
}
private func fetchModelSizes(for models: [ModelInfo]) async {
await withTaskGroup(of: Void.self) { group in
for (_, model) in models.enumerated() {

View File

@ -117,7 +117,7 @@ static inline uint64_t getTimeInUs() {
}
std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward = MNN_FORWARD_CPU, bool only_inference = true,
int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) {
int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false, bool enableKleidiAI=false) {
auto revertor = std::unique_ptr<Revert>(new Revert(model.model_file.c_str()));
if (testQuantModel) {
revertor->initialize(0, sparseBlockOC, false, true);
@ -130,6 +130,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy);
revertor.reset();
net->setSessionMode(MNN::Interpreter::Session_Release);
net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, enableKleidiAI);
MNN::ScheduleConfig config;
config.numThread = numberThread;
config.type = static_cast<MNNForwardType>(forward);
@ -392,8 +393,9 @@ int main(int argc, const char* argv[]) {
int precision = 2;
float sparsity = 0.0f;
int sparseBlockOC = 1;
bool enableKleidiAI = false;
if (argc <= 2) {
std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel]" << std::endl;
std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel] [enableKleidiAI]" << std::endl;
return 1;
}
if (argc >= 3) {
@ -420,8 +422,11 @@ int main(int argc, const char* argv[]) {
if(argc >= 10) {
testQuantizedModel = atoi(argv[9]);
}
if (argc >= 11) {
enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
}
std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << std::endl;
std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << " enableKleidiAI=" << enableKleidiAI << std::endl;
std::vector<Model> models = findModelFiles(argv[1]);
std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl;
@ -441,10 +446,10 @@ int main(int argc, const char* argv[]) {
}
for (auto& m : models) {
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false);
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false, enableKleidiAI);
displayStats(m.name.c_str(), costs, false);
if (testQuantizedModel) {
costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true);
costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true, enableKleidiAI);
displayStats(m.name, costs, 1);
}
}

View File

@ -98,5 +98,5 @@ MNN使用CMake构建项目CMake中的宏定义列表如下
| MNN_SUPPORT_TRANSFORMER_FUSE | 是否支持Fuse Transformer相关OP实现,默认为 `OFF` |
| MNN_BUILD_LLM | 是否构建基于MNN的llm库和demo,默认为`OFF` |
| MNN_BUILD_DIFFUSION | 是否构建基于MNN的diffusion demo,需要打开MNN_BUILD_OPENCV和MNN_IMGCODECS宏使用,默认为`OFF` |
| MNN_KLEIDIAI | 是否集成ARM的klediAI加速库【目前处于实验状态只能跑对称量化的LLM模型】默认为`OFF` |
| MNN_KLEIDIAI | 是否集成ARM的KleidiAI加速库,默认为`ON` |
| MNN_USE_RVV | 是否启用RISC-V向量扩展支持默认为`OFF` |

View File

@ -252,7 +252,10 @@ public:
CPU_CORE_IDS = 14,
// set CPU threads to use when supports Arm sme2
CPU_SME2_INSTRUCTIONS = 15
CPU_SME2_INSTRUCTIONS = 15,
// Enable KleidiAI
CPU_ENABLE_KLEIDIAI = 16
};
enum ExternalPathType {

View File

@ -21,10 +21,6 @@
#include "ThreadPool.hpp"
#endif
#ifdef MNN_KLEIDIAI_ENABLED
#include "arm/mnn_kleidiai.h"
#endif
namespace MNN {
class WorkerThread;
class CPURuntime : public Runtime {

View File

@ -19,8 +19,7 @@ if (MNN_CPU_WEIGHT_DEQUANT_GEMM)
FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/normal_memory/*.[sS])
endif()
if (MNN_KLEIDIAI)
add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
if (MNN_KLEIDIAI AND CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
# Disable the KleidiAI tests
set(KLEIDIAI_BUILD_TESTS OFF)
# Fetch KleidiAI sources:
@ -121,6 +120,9 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?")
endif()
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
message(STATUS "Enabling AArch64 Assemblies")
if (MNN_KLEIDIAI)
add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
endif()
if (MNN_SME2)
add_definitions(-DMNN_SME2)
FILE(GLOB MNN_SME2_AArch64_SRC ${MNN_SME2_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/sme2_asm/*.[sS])

View File

@ -23,6 +23,7 @@
#include "core/OpCommonUtils.hpp"
#include "backend/cpu/OneDNNConvolution.hpp"
#include "backend/cpu/compute/ConvInt8TiledExecutor.hpp"
#ifdef MNN_KLEIDIAI_ENABLED
#include "backend/cpu/compute/KleidiAIConvInt8.hpp"
#include "backend/cpu/compute/KleidiAIConvolution.hpp"
@ -31,6 +32,92 @@
namespace MNN {
#ifdef MNN_KLEIDIAI_ENABLED
static Execution* _createKleidiAIUnit(const Tensor* input, const Tensor* output, Backend* backend, const Op* op,
const float* originWeight, size_t originWeightSize, const float* bias,
size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo,
bool supportSparse, bool lowMemory) {
auto cpuBackend = (CPUBackend*)backend;
auto conv2d = op->main_as_Convolution2D();
auto common = conv2d->common();
bool fastWay = common->kernelY() == 1 && common->kernelX() == 1 && output->width() == input->width() &&
output->height() == input->height() && common->strideX() == 1 && common->strideY() == 1;
#ifdef MNN_LOW_MEMORY
if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
do {
if (!weightQuantInfo->canUseInt4) {
break;
}
auto convOp = op->main_as_Convolution2D();
auto core = static_cast<CPUBackend*>(backend)->functions();
int oc = convOp->common()->outputCount();
int ic = convOp->common()->inputCount();
int blockNum = 1;
int dequantCnt = weightQuantInfo->alphaSize;
if (weightQuantInfo->asymmetric) {
dequantCnt /= 2;
}
blockNum = dequantCnt / oc;
bool bAsym = weightQuantInfo->asymmetric;
size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
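// Worked example (illustrative shapes, assumed rather than taken from a real model):
// with oc = ic = 4096, asymmetric int4 weights and 32 quant blocks per output channel,
// alphaSize = 4096 * 32 * 2 = 262144, so dequantCnt = 131072, blockNum = 131072 / 4096 = 32
// and blkSize = 4096 / 32 = 128 weights per block (blkSize 0 would mean per-channel scales).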
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (!kai.canAccelerate(accelType, convOp->common())) {
break;
}
if (!kai.isLoaded(accelType)) {
kai.setLoaded(accelType);
kai.printInfo(accelType);
}
return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
} while (0);
}
// Other quantized-weight cases are not supported yet; fall back to the default path.
return nullptr;
}
#else
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo);
}
// Do nothing and fall back.
return nullptr;
}
#endif
// This differs from the original implementation: it mirrors the Strassen path,
// which is used when built without MNN_REDUCE_SIZE, but KleidiAI does not need
// to make that distinction.
if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
auto bytes = cpuBackend->functions()->bytes;
auto accelType = (bytes == 2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (kai.canAccelerate(accelType)) {
return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
}
}
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo);
}
return nullptr;
}
#endif //MNN_KLEIDIAI_ENABLED
static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend* backend,
const Op* op, const float* originWeight, size_t originWeightSize, const float* bias, size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo, bool supportSparse, bool lowMemory) {
auto cpuBackend = (CPUBackend*)backend;
@ -48,47 +135,24 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
}
}
#endif
#ifdef MNN_KLEIDIAI_ENABLED
if (cpuBackend->getRuntime()->hint().enableKleidiAI) {
auto execution = _createKleidiAIUnit(input, output, backend, op, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo, supportSparse, lowMemory);
if (execution) {
return execution;
}
}
#endif //MNN_KLEIDIAI_ENABLED
bool fastWay = common->kernelY() == 1 && common->kernelX() == 1
&& output->width() == input->width() && output->height() == input->height()
&& common->strideX() == 1 && common->strideY() == 1;
#ifdef MNN_LOW_MEMORY
if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
#ifdef MNN_KLEIDIAI_ENABLED
do {
if (!weightQuantInfo->canUseInt4) {
break;
}
auto convOp = op->main_as_Convolution2D();
auto core = static_cast<CPUBackend*>(backend)->functions();
int oc = convOp->common()->outputCount();
int ic = convOp->common()->inputCount();
int blockNum = 1;
int dequantCnt = weightQuantInfo->alphaSize;
if (weightQuantInfo->asymmetric) {
dequantCnt /= 2;
}
blockNum = dequantCnt / oc;
bool bAsym = weightQuantInfo->asymmetric;
size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if(!kai.isLoaded(accelType)) {
kai.setLoaded(accelType);
kai.printInfo(accelType);
}
if(!kai.canAccelerate(accelType, convOp->common())){
break;
}
return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
} while (0);
#endif
return new DenseConvInt8TiledExecutor(backend, op, weightQuantInfo, true);
} else {
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
@ -96,37 +160,16 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
}
#else
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
#ifdef MNN_KLEIDIAI_ENABLED
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
#ifndef MNN_REDUCE_SIZE
if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
#ifdef MNN_KLEIDIAI_ENABLED
auto bytes = cpuBackend->functions()->bytes;
auto accelType = (bytes==2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (kai.canAccelerate(accelType)){
return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
}
#endif //MNN_KLEIDIAI_ENABLED
return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize);
}
#endif
#ifdef MNN_KLEIDIAI_ENABLED
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
if (cpuBackend->getRuntime()->hint().winogradMemoryUsed == 0 || (!ConvolutionWinogradBridge::canUseWinograd(common))) {
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, nullptr);
}

View File

@ -303,4 +303,4 @@ ErrorCode KleidiAIConvInt8::onExecute(const std::vector<Tensor*>& inputs, const
}
} // namespace MNN
#endif //MNN_KLEIDIAI_ENABLED
#endif //MNN_KLEIDIAI_ENABLED

View File

@ -6,10 +6,8 @@
#ifndef KleidiAIConvInt8_hpp
#define KleidiAIConvInt8_hpp
#ifdef MNN_KLEIDIAI_ENABLED
#include "backend/cpu/CPUConvolution.hpp"
#include "Int8FunctionsOpt.h"
#include "CommonOptFunction.h"
#include "backend/cpu/arm/mnn_kleidiai.h"
namespace MNN {
class KleidiAIConvInt8 : public CPUConvolution {
@ -31,5 +29,4 @@ private:
};
} // namespace MNN
#endif // MNN_KLEIDIAI_ENABLED
#endif /* KleidiAIConvInt8_hpp */
#endif /* KleidiAIConvInt8_hpp */

View File

@ -3,19 +3,15 @@
//
// SPDX-License-Identifier: Apache-2.0
//
#ifdef MNN_KLEIDIAI_ENABLED
#include "KleidiAIConvolution.hpp"
#include <string.h>
#include "core/BufferAllocator.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
#include "backend/cpu/CPUTensorConvert.hpp"
namespace MNN {
#ifndef MNN_REDUCE_SIZE
KleidiAIConvolution::KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight,
size_t originWeightSize, const float *bias, size_t biasSize)
: CPUConvolution(common, b) {
@ -228,7 +224,5 @@ ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, co
return NO_ERROR;
}
#endif
} // namespace MNN
#endif //MNN_KLEIDIAI_ENABLED
#endif //MNN_KLEIDIAI_ENABLED

View File

@ -6,12 +6,9 @@
#ifndef KleidiAIConvolution_hpp
#define KleidiAIConvolution_hpp
#ifdef MNN_KLEIDIAI_ENABLED
#include <functional>
#include "backend/cpu/CPUConvolution.hpp"
#include "backend/cpu/arm/mnn_kleidiai.h"
namespace MNN {
#ifndef MNN_REDUCE_SIZE
class KleidiAIConvolution : public CPUConvolution{
public:
KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight, size_t originWeightSize, const float *bias, size_t biasSize);
@ -30,8 +27,5 @@ class KleidiAIConvolution : public CPUConvolution{
KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER;
std::vector<float> mPostParameters;
};
#endif //MNN_KLEIDIAI_ENABLED
} // namespace MNN
#endif
#endif /* KleidiAIConvolution_hpp */

View File

@ -1,4 +1,4 @@
#if MNN_KLEIDIAI_ENABLED
#ifdef MNN_KLEIDIAI_ENABLED
#include "KleidiAIDenseConvolution.hpp"
#include <numeric>
@ -9,6 +9,7 @@
#include "backend/cpu/CPUTensorConvert.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include "core/Concurrency.h"
#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
@ -304,10 +305,15 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
.dilatedWidth = mCommon->dilateX(),
};
mFunction.second = [=](int tid) {
int threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
mFunction.second = [=](int tId) {
// Convert NC4HW4 to NHWC
auto inputShape = input->shape(); // TODO check for NC4HW4, should be the NCHW
CPUTensorConverter::convert(input, &mInputNHWC, core);
// CPUTensorConverter::convert(input, &mInputNHWC, core);
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
CPUTensorConverter::convert(input, &mInputNHWC, core, tId, threadNum);
};
MNN_CONCURRENCY_END();
// Lhs packing
if (bytes == 4) {
int blockSize = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme();
@ -348,7 +354,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
}
// Convert NHWC to NC4HW4
CPUTensorConverter::convert(&mOutputNHWC, output, core);
// CPUTensorConverter::convert(&mOutputNHWC, output, core);
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
CPUTensorConverter::convert(&mOutputNHWC, output, core, tId, threadNum);
};
MNN_CONCURRENCY_END();
};
return NO_ERROR;
}
@ -359,4 +369,4 @@ ErrorCode KleidiAIDenseConvolutionImpl::onExecute(const std::vector<Tensor*>& in
return NO_ERROR;
}
} // namespace MNN
#endif
#endif //MNN_KLEIDIAI_ENABLED

View File

@ -1,5 +1,3 @@
#if MNN_KLEIDIAI_ENABLED
#ifndef KleidiAIDenseConvolution_hpp
#define KleidiAIDenseConvolution_hpp
@ -242,4 +240,3 @@ private:
} // namespace MNN
#endif /* KleidiAIDenseConvolution_hpp */
#endif

View File

@ -62,6 +62,8 @@ struct RuntimeHint {
// whether to use Arm sme2 cores when threads>1
bool useArmSme2Cores = true;
bool enableKleidiAI = false;
// Use CPU Ids
std::vector<int> cpuIds;
};
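For code that drives an `Executor` directly, the same switch is reachable through this hint struct, as the `run_test.out` changes further below show. A minimal sketch of that pattern; the include paths and setup values are assumptions for illustration, not part of this commit:

```cpp
#include <MNN/MNNForwardType.h>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>

void runWithKleidiAI() {
    MNN::BackendConfig config;
    auto exe = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, config, 4);
    MNN::Express::ExecutorScope scope(exe);
    MNN::RuntimeHint hint;            // header providing RuntimeHint is assumed to be visible here
    hint.enableKleidiAI = true;       // defaults to false, so KleidiAI stays off unless requested
    scope.Current()->getRuntime().second->setRuntimeHint(hint);
    // ... build modules / expressions inside this scope as usual ...
}
```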

View File

@ -109,6 +109,9 @@ void Session::ModeGroup::setHint(Interpreter::HintMode hint, int value) {
case Interpreter::HintMode::INIT_THREAD_NUMBER:
runtimeHint.initThreadNumber = value;
break;
case Interpreter::HintMode::CPU_ENABLE_KLEIDIAI:
runtimeHint.enableKleidiAI = value > 0 ? true : false;
break;
default:
break;
}
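At the session/module level the hint is consumed here and can be set through a runtime manager, mirroring the `ModuleBasic` change further below. A minimal sketch; the configuration values are placeholders and the include choices are assumed:

```cpp
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>
#include <memory>

int main() {
    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;
    config.numThread = 4;
    std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr(
        MNN::Express::Executor::RuntimeManager::createRuntimeManager(config));
    // Any value > 0 enables the KleidiAI paths; the default (0 / false) keeps them off.
    rtmgr->setHint(MNN::Interpreter::CPU_ENABLE_KLEIDIAI, true);
    // ... create a Module with this runtime manager and run inference as usual ...
    return 0;
}
```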

View File

@ -19,10 +19,6 @@
#undef CONSTANT
#endif // CONSTANT
#ifdef MNN_KLEIDIAI_ENABLED
#include "../backend/cpu/arm/mnn_kleidiai.h"
#endif
namespace MNN {
struct TensorArrayAttr {
// array size is dynamic or not

View File

@ -92,3 +92,7 @@ MNNForwardType getCurrentType() {
return attr->firstType;
}
std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor() {
auto attr = MNN::Express::ExecutorScope::Current()->getAttr();
return MNN::Express::Executor::newExecutor(getCurrentType(), attr->config, attr->numThread);
}
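This helper clones the currently scoped executor (same forward type, config and thread count) so a test body can run inside its own `ExecutorScope`. A hypothetical test case showing the two-line pattern the test changes below adopt:

```cpp
// Hypothetical test case, for illustration only; the real tests below use the same two lines.
class ExampleOpTest : public MNNTestCase {
public:
    virtual bool run(int precision) {
        auto executor = cloneCurrentExecutor();  // clone of the active executor
        ExecutorScope scope(executor);           // everything below runs on the clone
        // ... build expressions or modules and verify the outputs ...
        return true;
    }
};
```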

View File

@ -104,6 +104,8 @@ inline float keepFP32Precision(float fp32Value) {
}
MNNForwardType getCurrentType();
std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor();
using ConvertFP32 = float(*)(float fp32Value);
const static std::vector<ConvertFP32> FP32Converter = {

View File

@ -49,6 +49,12 @@ public:
for (int i = 0; i < channel * height * width; ++i){
inputData[i] = (rand() % 10) * 0.1;
}
MNN::BackendConfig config;
config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
ExecutorScope scope(executor);
auto net = _createModel();
auto x = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());
@ -65,11 +71,7 @@ public:
// clone model
MNN::BackendConfig config;
config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
ExecutorScope scope(executor);
std::unique_ptr<Module> tempModule(Module::clone(net.get()));
auto xClone = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());

View File

@ -14,12 +14,16 @@
#include <MNN/expr/Module.hpp>
#include "MNNTestSuite.h"
#include "MNN_generated.h"
#include "TestUtils.h"
using namespace MNN;
using namespace MNN::Express;
class GatherExprTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::unique_ptr<MNN::OpT> gatherOp(new MNN::OpT);
gatherOp->type = MNN::OpType_GatherND;
auto parameter = _Input({2, 2}, NHWC, halide_type_of<int32_t>());
@ -224,7 +228,8 @@ public:
class GatherNdReComputeTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
const int indices_data[] = {0, 0, 1, 1};
auto params = _Const(inpudata, {2, 2}, NHWC, halide_type_of<float>());

View File

@ -22,6 +22,8 @@ public:
MNN_ERROR("Currently don't test not cpu mmap\n");
return true;
}
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto x = _Input({1, 3, 224, 224}, NC4HW4, halide_type_of<float>());
x->setName("x");
auto y = _Conv(1.0f, 0.01f, x, {3, 16}, {5, 5});

View File

@ -11,6 +11,7 @@
#include <random>
#include "MNNTestSuite.h"
#include "MNN_generated.h"
#include "TestUtils.h"
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
@ -66,6 +67,8 @@ static void _originMatMul(float* C, const float* A, const float* B, int e, int l
class MatMulTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
int e = 5, h = 4, l = 6;
if (true) {
// Test MatMul

View File

@ -2,6 +2,7 @@
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
using namespace MNN;
using namespace MNN::Express;
@ -15,6 +16,8 @@ public:
return summer;
}
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<VARP> empty;
// Make Net
auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>());

View File

@ -177,6 +177,8 @@ MNNTestSuiteRegister(ModuleTest, "expr/ModuleTest");
class ModuleWrongInputTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
// construct
{
@ -244,6 +246,8 @@ MNNTestSuiteRegister(ModuleWrongInputTest, "expr/ModuleWrongInputTest");
class RefTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
// construct
{
@ -318,6 +322,8 @@ public:
}
}
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
#ifdef MNN_REDUCE_SIZE
return true;
@ -1039,6 +1045,8 @@ MNNTestSuiteRegister(ConstMemoryReplaceTest, "expr/ConstMemoryReplaceTest");
class MutlThreadConstReplaceTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto func = [precision](VARP y, int thread) {
flatbuffers::FlatBufferBuilder builderOutput(1024);
{
@ -1499,6 +1507,8 @@ MNNTestSuiteRegister(ExecutorResetLoadModuleTest, "expr/ExecutorResetLoadModuleT
class SequenceForwardResizeTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
// Make Model include convolution in shape compute and content compute
auto x = _Input({1, 3, 24, 24}, NCHW, halide_type_of<float>());
x->setName("x");
@ -1606,6 +1616,8 @@ MNNTestSuiteRegister(SequenceForwardResizeTest, "expr/SequenceForwardResizeTest"
class InputModuleTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto y = _mobileNetV1Expr(nullptr, false);
std::unique_ptr<MNN::NetT> net(new NetT);
Variable::save({y}, net.get());

View File

@ -35,6 +35,8 @@ static std::shared_ptr<Module> _createModel() {
class RasterOutputTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto net = _createModel();
auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>());
auto y = _Transpose(x, {0, 1, 3, 2});

View File

@ -66,6 +66,11 @@ int main(int argc, char* argv[]) {
dynamicOption = atoi(argv[7]);
FUNC_PRINT(dynamicOption);
}
bool enableKleidiAI = false;
if (argc > 8) {
enableKleidiAI = atoi(argv[8]) > 0 ? true : false;
FUNC_PRINT(enableKleidiAI);
}
auto exe = MNN::Express::Executor::newExecutor(type, config, thread);
if (exe == nullptr) {
MNN_ERROR("Can't create executor with type:%d, exit!\n", type);
@ -76,6 +81,7 @@ int main(int argc, char* argv[]) {
// set hint
MNN::RuntimeHint hint;
hint.dynamicQuantOption = dynamicOption;
hint.enableKleidiAI = enableKleidiAI;
scope.Current()->getRuntime().second->setRuntimeHint(hint);
MNNTestSuite::get()->pStaus.memory = memory;
MNNTestSuite::get()->pStaus.precision = precision;

View File

@ -202,6 +202,8 @@ class BinaryBroadcastTest : public MNNTestCase {
virtual ~BinaryBroadcastTest() = default;
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
bool resultNCHW = testDimensionFormat(NCHW, precision);
bool resultNHWC = testDimensionFormat(NHWC, precision);

View File

@ -17,6 +17,8 @@ class ReverseTest : public MNNTestCase {
public:
virtual ~ReverseTest() = default;
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::shared_ptr<MNN::Express::Module> net;
{
auto input = _Input({3, 2, 3}, NCHW);

View File

@ -109,7 +109,7 @@ static inline std::vector<int> parseIntList(const std::string& str, char delim)
int main(int argc, char *argv[]) {
if (argc < 3) {
MNN_PRINT("=======================================================================================================================================\n");
MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds]\n");
MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds] [enableKleidiAI]\n");
MNN_PRINT("=======================================================================================================================================\n");
return 0;
}
@ -247,11 +247,16 @@ int main(int argc, char *argv[]) {
for (auto id : cpuIds) {
MNN_PRINT("%d ", id);
}
bool enableKleidiAI = false;
if (argc > 10) {
enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
}
MNN_PRINT("\n");
FUNC_PRINT(precision);
FUNC_PRINT(memory);
FUNC_PRINT(power);
FUNC_PRINT_ALL(cacheFileName, s);
FUNC_PRINT(enableKleidiAI);
// create session
MNN::ScheduleConfig config;
config.type = type;
@ -320,6 +325,10 @@ int main(int argc, char *argv[]) {
rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 2);
}
if (enableKleidiAI) {
rtmgr->setHint(Interpreter::CPU_ENABLE_KLEIDIAI, true);
}
// rtmgr->setHint(Interpreter::CPU_SME2_INSTRUCTIONS, false);
if (runMask & 2048) {

View File

@ -645,18 +645,21 @@ VARP Omni::gen_position_ids(int seq_len) {
positionIds = _Input({3, seq_len}, NCHW, halide_type_of<int>());
}
auto ptr = positionIds->writeMap<int>();
if (mContext->gen_seq_len > 0) {
for (int i=0; i<seq_len; ++i) {
auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
ptr[i + 0] = pos;
ptr[i + seq_len] = pos;
ptr[i + seq_len * 2] = pos;
}
if (seq_len == 1) {
ptr[0] = mContext->gen_seq_len + mPositionIds.back();
ptr[1] = ptr[0];
ptr[2] = ptr[0];
} else {
for (int i = 0; i < seq_len; i++) {
ptr[i] = mPositionIds.mT[i];
ptr[i + seq_len] = mPositionIds.mH[i];
ptr[i + seq_len * 2] = mPositionIds.mW[i];
if (mPositionIds.mT.size() != seq_len) {
ptr[i] = i;
ptr[i + seq_len] = i;
ptr[i + seq_len * 2] = i;
} else {
ptr[i] = mPositionIds.mT[i];
ptr[i + seq_len] = mPositionIds.mH[i];
ptr[i + seq_len * 2] = mPositionIds.mW[i];
}
}
if (mTalker) {
mTalker->setPostionIds(mPositionIds);