mirror of https://github.com/alibaba/MNN.git
				
				
				
			Merge branch 'master' into feature/rvv-opt
This commit is contained in:
		
						commit
						4cf16d6761
					
				|  | @ -236,7 +236,7 @@ option(MNN_OPENGL "Enable OpenGL" OFF) | |||
| option(MNN_VULKAN "Enable Vulkan" OFF) | ||||
| option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) | ||||
| option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) | ||||
| option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF) | ||||
| option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) | ||||
| option(MNN_ONEDNN "Enable oneDNN" OFF) | ||||
| option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) | ||||
| option(MNN_AVX512 "Enable AVX512" OFF) | ||||
|  |  | |||
|  | @ -11,19 +11,13 @@ | |||
| 
 | ||||
| 
 | ||||
| ## News 🔥 | ||||
| - [2025/08/05] MNN Chat Android is availabe in [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release) ! | ||||
| - [2025/06/11] New App MNN-TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN-TaoAvatar](./apps/Android/MnnTaoAvatar/README.md)  | ||||
| <p align="center"> | ||||
|   <img width="20%" alt="Icon"  src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;"> | ||||
| </p> | ||||
| 
 | ||||
| - [2025/05/30] MNN Chat app support DeepSeek-R1-0528-Qwen3,Qwen3-30B-A3B, SmoVLM and FastVLM [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases). | ||||
| - [2025/05/12] android app support qwen2.5 omni 3b and 7b [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases). | ||||
| <p align="center"> | ||||
|   <img width="20%" alt="Icon"  src="./apps/Android/MnnLlmChat/assets/image_home_new.jpg" style="margin: 0 10px;"> | ||||
|   <img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_sound_new.jpg" style="margin: 0 10px;"> | ||||
|   <img width="20%" alt="Icon" src="./apps/Android/MnnLlmChat/assets/image_image_new.jpg" style="margin: 0 10px;"> | ||||
| </p> | ||||
| 
 | ||||
| 
 | ||||
| <details> | ||||
| <summary> History News </summary> | ||||
|  |  | |||
|  | @ -3,6 +3,8 @@ | |||
| 
 | ||||
| [Download](#releases)  [下载](./README_CN.md#releases) | ||||
| 
 | ||||
| [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release) | ||||
| 
 | ||||
| [iOS App](../../iOS/MNNLLMChat/README.md) | ||||
| 
 | ||||
| ## Introduction | ||||
|  | @ -59,6 +61,14 @@ This is our full multimodal language model (LLM) Android app | |||
|   ``` | ||||
| 
 | ||||
| # Releases | ||||
| ## Version 0.6.8 | ||||
| + Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk) | ||||
| + add new models:  SmolLM3-3B、gemma-3-1b  | ||||
| + support penalty sampler in mixed sampler mode. | ||||
| + can switch models in chat screen. | ||||
| + can update models when the remote models changed. | ||||
| + fix download source for huggingface. | ||||
| + Support  Realtime voice call with ASR and TTS | ||||
| ## Version 0.5.1.2 | ||||
| + Click here to [download](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk) | ||||
| + fix huggingface download error | ||||
|  |  | |||
|  | @ -53,6 +53,13 @@ | |||
|   ``` | ||||
| 
 | ||||
| # Releases | ||||
| + 点击这里  [下载](https://meta.alicdn.com/data/mnn/mnn_chat_0_6_8.apk) | ||||
| + 新增模型支持 :现已支持 SmolLM3-3B 和 gemma-3-1b 模型。 | ||||
| + 采样器功能增强 :在混合采样器模式中新增对 penalty sampler 的支持,提升生成质量与多样性。 | ||||
| + 模型切换优化 :在聊天界面中支持 实时切换模型 ,提升使用灵活性。 | ||||
| + 模型热更新支持 :当远程模型发生变化时,无需删除原有模型即可完成更新,避免重复下载。 | ||||
| + 下载源优化 :优化了 HuggingFace 模型的下载源,提升下载速度与稳定性。 | ||||
| + 新增语音通话功能 :支持 实时语音对话 ,集成语音识别(ASR)与语音合成(TTS)功能,带来更丰富的交互体验。 | ||||
| 
 | ||||
| ## Version 0.5.1.2 | ||||
| + 点击这里  [下载](https://meta.alicdn.com/data/mnn/mnn_chat_0_5_1_2.apk) | ||||
|  |  | |||
|  | @ -1 +1,5 @@ | |||
| /build | ||||
| /build | ||||
| 
 | ||||
| # Native libraries (downloaded at build time) | ||||
| src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.so | ||||
| src/main/jniLibs/arm64-v8a/libsherpa-mnn-jni.zip | ||||
|  | @ -5,6 +5,49 @@ plugins { | |||
| 
 | ||||
| } | ||||
| 
 | ||||
| task downloadAndUnzipNativeLibs { | ||||
|     group = 'Pre-build' | ||||
|     description = 'Downloads and unzips libsherpa-mnn-jni native library from CDN.' | ||||
|     def nativeLibsUrl = 'https://meta.alicdn.com/data/mnn/libs/libsherpa-mnn-jni-16k.zip' | ||||
|     def zipFileName = 'libsherpa-mnn-jni-16k.zip' | ||||
|     def outputDir = file('src/main/jniLibs/arm64-v8a') | ||||
|     def downloadedZip = new File(project.buildDir, zipFileName) | ||||
|     def checkFile = new File(outputDir, 'libsherpa-mnn-jni.so') | ||||
|      | ||||
|     inputs.property('url', nativeLibsUrl) | ||||
|     outputs.file(checkFile) | ||||
|      | ||||
|     doLast { | ||||
|         println "-> Executing downloadAndUnzipNativeLibs task..." | ||||
|         println "   Downloading from ${nativeLibsUrl}" | ||||
|          | ||||
|         // Create output directory if it doesn't exist | ||||
|         outputDir.mkdirs() | ||||
|          | ||||
|         ant.get(src: nativeLibsUrl, dest: downloadedZip) | ||||
| 
 | ||||
|         if (!downloadedZip.exists()) { | ||||
|             throw new GradleException("Download failed: ${downloadedZip} not found.") | ||||
|         } | ||||
|         println "   Download complete." | ||||
|         println "   Unzipping ${downloadedZip.name} to ${outputDir}..." | ||||
|          | ||||
|         copy { | ||||
|             from(zipTree(downloadedZip)) | ||||
|             into(outputDir) | ||||
|         } | ||||
|         println "   Unzip complete." | ||||
|         downloadedZip.delete() | ||||
|     } | ||||
|      | ||||
|     onlyIf { | ||||
|         println "-> Checking if native libs exist... [Exists: ${checkFile.exists()}]" | ||||
|         return !checkFile.exists() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| preBuild.dependsOn downloadAndUnzipNativeLibs | ||||
| 
 | ||||
| android { | ||||
|     namespace 'com.alibaba.mnnllm.android' | ||||
|     compileSdk 35 | ||||
|  | @ -16,8 +59,8 @@ android { | |||
|         applicationId "com.alibaba.mnnllm.android" | ||||
|         minSdk 26 | ||||
|         targetSdk 35 | ||||
|         versionCode 606 | ||||
|         versionName "0.6.6" | ||||
|         versionCode 608 | ||||
|         versionName "0.6.8" | ||||
| 
 | ||||
|         testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" | ||||
|         externalNativeBuild { | ||||
|  | @ -56,18 +99,22 @@ android { | |||
| 
 | ||||
|         signingConfigs { | ||||
|             release { | ||||
|                 storeFile file(System.getenv("KEYSTORE_FILE")) | ||||
|                 storePassword System.getenv("KEYSTORE_PASSWORD") | ||||
|                 keyAlias System.getenv("KEY_ALIAS") | ||||
|                 keyPassword System.getenv("KEY_PASSWORD") | ||||
|                 def keystoreFile = System.getenv("KEYSTORE_FILE") | ||||
|                 if (keystoreFile) { | ||||
|                     storeFile file(keystoreFile) | ||||
|                     storePassword System.getenv("KEYSTORE_PASSWORD") | ||||
|                     keyAlias System.getenv("KEY_ALIAS") | ||||
|                     keyPassword System.getenv("KEY_PASSWORD") | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         buildTypes { | ||||
|             release { | ||||
|                 signingConfig signingConfigs.release | ||||
|                 if (System.getenv("KEYSTORE_FILE")) { | ||||
|                     signingConfig signingConfigs.release | ||||
|                 } | ||||
|                 applicationIdSuffix ".release" | ||||
|             } | ||||
|             debug { | ||||
|                 versionNameSuffix ".gp" | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  |  | |||
|  | @ -1,3 +1,3 @@ | |||
| <resources> | ||||
|     <string name="app_name">MNN Chat-dev</string> | ||||
|     <string name="app_name">MNN Chat</string> | ||||
| </resources>  | ||||
|  | @ -1,3 +1,3 @@ | |||
| <resources> | ||||
|     <string name="app_name">MNN Chat-dev</string> | ||||
|     <string name="app_name">MNN Chat</string> | ||||
| </resources>  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,97 @@ | |||
| // Created by ruoyi.sjd on 2025/1/27. | ||||
| // Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved. | ||||
| 
 | ||||
| package com.alibaba.mls.api.download | ||||
| 
 | ||||
| import android.util.Log | ||||
| import kotlinx.coroutines.CoroutineDispatcher | ||||
| import kotlinx.coroutines.CoroutineScope | ||||
| import kotlinx.coroutines.Dispatchers | ||||
| import kotlinx.coroutines.SupervisorJob | ||||
| import kotlinx.coroutines.launch | ||||
| 
 | ||||
| /** | ||||
|  * Centralized coroutine management for download operations | ||||
|  * Provides configurable scope and dispatcher to avoid threading issues | ||||
|  *  | ||||
|  * Usage: | ||||
|  * ``` | ||||
|  * // Configure in Application.onCreate() | ||||
|  * DownloadCoroutineManager.configureDispatcher(Dispatchers.IO) // or Dispatchers.Default | ||||
|  *  | ||||
|  * // In downloaders | ||||
|  * DownloadCoroutineManager.launchDownload { | ||||
|  *     // download logic | ||||
|  * } | ||||
|  * ``` | ||||
|  */ | ||||
| object DownloadCoroutineManager { | ||||
|     private const val TAG = "DownloadCoroutineManager" | ||||
|      | ||||
|     /** | ||||
|      * Configurable dispatcher for download operations | ||||
|      * Default to Dispatchers.Default to avoid IO dispatcher blocking issues | ||||
|      */ | ||||
|     var downloadDispatcher: CoroutineDispatcher = Dispatchers.Default | ||||
|         private set | ||||
|      | ||||
|     /** | ||||
|      * Coroutine scope for download operations | ||||
|      */ | ||||
|     private val _downloadScope by lazy { | ||||
|         CoroutineScope(downloadDispatcher + SupervisorJob()) | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Configure the dispatcher used for download operations | ||||
|      * @param dispatcher The coroutine dispatcher to use | ||||
|      */ | ||||
|     fun configureDispatcher(dispatcher: CoroutineDispatcher) { | ||||
|         Log.d(TAG, "Configuring download dispatcher to: $dispatcher") | ||||
|         downloadDispatcher = dispatcher | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Initialize with IO dispatcher (call this when system is stable) | ||||
|      */ | ||||
|     fun initializeWithIO() { | ||||
|         configureDispatcher(Dispatchers.IO) | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Initialize with Default dispatcher (call this when IO dispatcher has issues) | ||||
|      */ | ||||
|     fun initializeWithDefault() { | ||||
|         configureDispatcher(Dispatchers.Default) | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Launch a download coroutine with proper error handling | ||||
|      * @param block The suspend function to execute | ||||
|      */ | ||||
|     fun launchDownload(block: suspend CoroutineScope.() -> Unit) { | ||||
|         _downloadScope.launch(downloadDispatcher, block = block) | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Get a scope with download dispatcher for manual coroutine management | ||||
|      */ | ||||
|     fun getDownloadScope(): CoroutineScope { | ||||
|         return _downloadScope | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Reset to default configuration (for testing or recovery) | ||||
|      */ | ||||
|     fun resetToDefault() { | ||||
|         Log.d(TAG, "Resetting to default dispatcher (Dispatchers.Default)") | ||||
|         downloadDispatcher = Dispatchers.Default | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|      * Get current dispatcher info for debugging | ||||
|      */ | ||||
|     fun getCurrentDispatcherInfo(): String { | ||||
|         return "Current download dispatcher: $downloadDispatcher" | ||||
|     } | ||||
| }  | ||||
|  | @ -9,6 +9,7 @@ data class DownloadInfo( | |||
|     var totalSize: Long = 0, | ||||
|     var speedInfo: String = "", | ||||
|     var errorMessage: String? = null, | ||||
|     var errorException: Exception? = null, | ||||
|     var lastLogTime: Long = 0, | ||||
|     var lastProgressUpdateTime: Long = 0, | ||||
|     var progressStage: String = "", | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ object DownloadPersistentData { | |||
|     const val METADATA_KEY: String = "meta_data" | ||||
|     const val SIZE_TOTAL_KEY: String = "size_total" | ||||
|     const val SIZE_SAVED_KEY: String = "size_saved" | ||||
|     const val SIZE_MARKET_TOTAL_KEY: String = "size_market_total" | ||||
|     const val DOWNLOADED_TIME_KEY: String = "downloaded_time" | ||||
| 
 | ||||
|     // Create preference keys with modelId | ||||
|  | @ -38,6 +39,9 @@ object DownloadPersistentData { | |||
|     private fun createMetaDataKey(modelId: String): Preferences.Key<String> =  | ||||
|         stringPreferencesKey("${METADATA_KEY}_$modelId") | ||||
| 
 | ||||
|     private fun createSizeMarketTotalKey(modelId: String): Preferences.Key<Long> =  | ||||
|         longPreferencesKey("${SIZE_MARKET_TOTAL_KEY}_$modelId") | ||||
| 
 | ||||
|     private fun createDownloadedTimeKey(modelId: String): Preferences.Key<Long> =  | ||||
|         longPreferencesKey("${DOWNLOADED_TIME_KEY}_$modelId") | ||||
| 
 | ||||
|  | @ -204,6 +208,35 @@ object DownloadPersistentData { | |||
|          | ||||
|         return dataStoreValue ?: 0L | ||||
|     } | ||||
| 
 | ||||
|     fun saveMarketSizeTotal(context: Context, modelId: String, total: Long) { | ||||
|         runBlocking { saveMarketSizeTotalSuspend(context, modelId, total) } | ||||
|     } | ||||
|      | ||||
|     suspend fun saveMarketSizeTotalSuspend(context: Context, modelId: String, total: Long) { | ||||
|         val normalizedModelId = ModelUtils.safeModelId(modelId) | ||||
|         val key = createSizeMarketTotalKey(normalizedModelId) | ||||
|          | ||||
|         context.downloadDataStore.edit { preferences -> | ||||
|             preferences[key] = total | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fun getMarketSizeTotal(context: Context, modelId: String): Long { | ||||
|         return runBlocking { getMarketSizeTotalSuspend(context, modelId) } | ||||
|     } | ||||
|      | ||||
|     suspend fun getMarketSizeTotalSuspend(context: Context, modelId: String): Long { | ||||
|         val normalizedModelId = ModelUtils.safeModelId(modelId) | ||||
|         val key = createSizeMarketTotalKey(normalizedModelId) | ||||
|          | ||||
|         // Try to read from DataStore | ||||
|         val dataStoreValue = context.downloadDataStore.data | ||||
|             .map { preferences -> preferences[key] } | ||||
|             .first() | ||||
|          | ||||
|         return dataStoreValue ?: 0L | ||||
|     } | ||||
|      | ||||
|     // Private migration helpers | ||||
|     private fun deleteSharedPrefsFileIfEmpty(context: Context, preferencesName: String) { | ||||
|  |  | |||
|  | @ -24,8 +24,6 @@ import com.alibaba.mnnllm.android.utils.CurrentActivityTracker | |||
| import com.alibaba.mnnllm.android.utils.FileUtils | ||||
| import com.alibaba.mnnllm.android.utils.MmapUtils.clearMmapCache | ||||
| import kotlinx.coroutines.Dispatchers | ||||
| import kotlinx.coroutines.MainScope | ||||
| import kotlinx.coroutines.launch | ||||
| import kotlinx.coroutines.withContext | ||||
| import java.io.File | ||||
| import java.util.concurrent.ConcurrentHashMap | ||||
|  | @ -42,7 +40,7 @@ class LoggingDownloadListener : DownloadListener { | |||
|     } | ||||
| 
 | ||||
|     override fun onDownloadProgress(modelId: String, downloadInfo: DownloadInfo) { | ||||
|         Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo}") | ||||
|         Log.v(ModelDownloadManager.TAG, "onDownloadProgress: $modelId, progress: ${downloadInfo.progress}, state: ${downloadInfo.downloadState}, speed: ${downloadInfo.speedInfo} stage: ${downloadInfo.progressStage}") | ||||
|     } | ||||
| 
 | ||||
|     override fun onDownloadFinished(modelId: String, path: String) { | ||||
|  | @ -253,7 +251,6 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|     fun getDownloadInfo(modelId: String): DownloadInfo { | ||||
|         Log.d(TAG, "getDownloadInfo: $modelId totalSize: ${DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelId)}" + | ||||
|                 " progress: ${getRealDownloadSize(modelId)}") | ||||
|         val source = ModelUtils.getSource(modelId)!! | ||||
|         if (!downloadInfoMap.containsKey(modelId)) { | ||||
|             val downloadInfo = DownloadInfo() | ||||
|             if (getDownloadedFile(modelId) != null) { | ||||
|  | @ -277,7 +274,11 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|             } | ||||
|             downloadInfoMap[modelId] = downloadInfo | ||||
|             if (downloadInfo.totalSize < 100) { | ||||
|                 getRepoSize(modelId, source) | ||||
|                 val marketSize = DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId) | ||||
|                 Log.d(TAG, "getMarketSize for ${modelId} size: $marketSize") | ||||
|                 if (marketSize > 0) { | ||||
|                     downloadInfo.totalSize = marketSize | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         val downloadInfo = downloadInfoMap[modelId]!! | ||||
|  | @ -305,26 +306,6 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|         return savedSize | ||||
|     } | ||||
| 
 | ||||
|     private fun getRepoSize( | ||||
|         modelId: String, | ||||
|         source: Any, | ||||
|         modelName: String = modelId | ||||
|     ) { | ||||
|         MainScope().launch { | ||||
|             val downloader = when (source) { | ||||
|                 is String -> getDownloaderForSource(source) | ||||
|                 is ModelSources.ModelSourceType -> getDownloaderForSource(source) | ||||
|                 else -> throw IllegalArgumentException("Unsupported source type: ${source::class.java.name}") | ||||
|             } | ||||
|             val repoSize = downloader.getRepoSize(modelId) | ||||
|             if (repoSize > 0 && DownloadPersistentData.getDownloadSizeTotal(ApplicationProvider.get(), modelName) <= 0L) { | ||||
|                 DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelName, repoSize) | ||||
|                 downloadInfoMap[modelName]?.totalSize = repoSize | ||||
|                 listeners.forEach { it.onDownloadTotalSize(modelName, repoSize) } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private fun setDownloadFinished(modelId: String, path: String) { | ||||
|         val downloadInfo = downloadInfoMap[modelId] ?: return | ||||
|         downloadInfo.downloadState = DownloadState.COMPLETED | ||||
|  | @ -422,6 +403,7 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|         val info = downloadInfoMap.getOrPut(modelId) { DownloadInfo() } | ||||
|         info.downloadState = DownloadState.FAILED | ||||
|         info.errorMessage = e.message | ||||
|         info.errorException = e | ||||
|         listeners.forEach { | ||||
|             Log.d(TAG, "[setDownloadFailed] Notifying listener: ${it.javaClass.simpleName}") | ||||
|             it.onDownloadFailed(modelId, e) | ||||
|  | @ -470,9 +452,6 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|         savedSize: Long, | ||||
|         totalSize: Long | ||||
|     ) { | ||||
|         if (totalSize <= 0) { | ||||
|             return | ||||
|         } | ||||
|         val downloadInfo = downloadInfoMap.getOrPut(modelId) { DownloadInfo() } | ||||
|         val currentTime = System.currentTimeMillis() | ||||
|         if (stage == downloadInfo.progressStage && currentTime - downloadInfo.lastProgressUpdateTime < 1000) { | ||||
|  | @ -486,11 +465,13 @@ class ModelDownloadManager private constructor(context: Context) { | |||
|         if (downloadInfo.savedSize <= 0) { | ||||
|             downloadInfo.savedSize = savedSize | ||||
|         } | ||||
|         downloadInfo.progress = savedSize.toDouble() / totalSize | ||||
|         DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize) | ||||
|         DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize) | ||||
|         calculateDownloadSpeed(downloadInfo, savedSize) | ||||
|         Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId") | ||||
|         if (totalSize > 0) { | ||||
|             downloadInfo.progress = savedSize.toDouble() / totalSize | ||||
|             DownloadPersistentData.saveDownloadSizeSaved(ApplicationProvider.get(), modelId, savedSize) | ||||
|             DownloadPersistentData.saveDownloadSizeTotal(ApplicationProvider.get(), modelId, totalSize) | ||||
|             calculateDownloadSpeed(downloadInfo, savedSize) | ||||
|         } | ||||
|         Log.v(TAG, "[updateDownloadingProgress] Notifying ${listeners.size} listeners for $modelId stage: $stage") | ||||
|         listeners.forEach { it.onDownloadProgress(modelId, downloadInfo) } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -25,8 +25,7 @@ import com.alibaba.mls.api.hf.HfFileMetadata | |||
| import com.alibaba.mls.api.hf.HfRepoInfo | ||||
| import com.alibaba.mnnllm.android.chat.model.ChatDataManager | ||||
| import com.alibaba.mnnllm.android.utils.TimeUtils | ||||
| import kotlinx.coroutines.Dispatchers | ||||
| import kotlinx.coroutines.runBlocking | ||||
| import com.alibaba.mls.api.download.DownloadCoroutineManager | ||||
| import kotlinx.coroutines.withContext | ||||
| import okhttp3.OkHttpClient | ||||
| import java.io.File | ||||
|  | @ -60,10 +59,11 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|      * Unified method to fetch repo information from HuggingFace API | ||||
|      * @param modelId the model ID to fetch info for | ||||
|      * @param calculateSize whether to calculate repo size (requires additional network requests) | ||||
|      * @return HfRepoInfo object or null if failed | ||||
|      * @return HfRepoInfo object | ||||
|      * @throws FileDownloadException if failed to fetch repo info | ||||
|      */ | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo? { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): HfRepoInfo { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val hfModelId = hfModelId(modelId) | ||||
|                 val response = getHfApiClient().apiService.getRepoInfo(hfModelId, "main")?.execute() | ||||
|  | @ -79,11 +79,16 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|                     callback?.onRepoInfo(modelId, lastModified, repoSize) | ||||
|                     repoInfo | ||||
|                 } else { | ||||
|                     null | ||||
|                     val errorMsg = if (response?.isSuccessful == false) { | ||||
|                         "API request failed with code ${response.code()}: ${response.message()}" | ||||
|                     } else { | ||||
|                         "API response was null or empty" | ||||
|                     } | ||||
|                     throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg") | ||||
|                 } | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to fetch repo info for $modelId", exception) | ||||
|                 null | ||||
|                 throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | @ -92,7 +97,7 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|      * Calculate total size of all files in the repo | ||||
|      */ | ||||
|     private suspend fun calculateRepoSize(hfRepoInfo: HfRepoInfo): Long { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val metaList = requestMetaDataList(hfRepoInfo) | ||||
|                 metaList.sumOf { it?.size ?: 0L } | ||||
|  | @ -101,20 +106,22 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|     } | ||||
| 
 | ||||
|     override fun download(modelId: String) { | ||||
|         executor!!.submit { | ||||
|             runBlocking { | ||||
|         DownloadCoroutineManager.launchDownload { | ||||
|             try { | ||||
|                 val repoInfo = fetchRepoInfo(modelId) | ||||
|                 if (repoInfo != null) { | ||||
|                     downloadHfRepo(repoInfo) | ||||
|                 } else { | ||||
|                     callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId")) | ||||
|                 } | ||||
|                 downloadHfRepo(repoInfo) | ||||
|             } catch (e: FileDownloadException) { | ||||
|                 callback?.onDownloadFailed(modelId, e) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override suspend fun checkUpdate(modelId: String) { | ||||
|         fetchRepoInfo(modelId) | ||||
|         try { | ||||
|             fetchRepoInfo(modelId) | ||||
|         } catch (e: FileDownloadException) { | ||||
|             Log.e(TAG, "Failed to check update for $modelId", e) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fun downloadHfRepo(hfRepoInfo: HfRepoInfo) { | ||||
|  | @ -127,17 +134,23 @@ class HfModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|     } | ||||
| 
 | ||||
|     override suspend fun getRepoSize(modelId: String): Long { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val repoInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|                 if (repoInfo != null) { | ||||
|                     // Size was already calculated in fetchRepoInfo, but we can also calculate it directly | ||||
|                     calculateRepoSize(repoInfo) | ||||
|                 } else { | ||||
|                     0L | ||||
|                 } | ||||
|                 // Size was already calculated in fetchRepoInfo, but we can also calculate it directly | ||||
|                 calculateRepoSize(repoInfo) | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to get repo size for $modelId", exception) | ||||
|                 // Try to get file_size from saved market data as fallback | ||||
|                 try { | ||||
|                     val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId) | ||||
|                     if (marketSize > 0) { | ||||
|                         Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize") | ||||
|                         return@withContext marketSize | ||||
|                     } | ||||
|                 } catch (e: Exception) { | ||||
|                     Log.w(TAG, "Failed to get saved market size for $modelId", e) | ||||
|                 } | ||||
|                 0L | ||||
|             } | ||||
|         } | ||||
|  |  | |||
|  | @ -22,8 +22,7 @@ import com.alibaba.mls.api.ml.FileInfo | |||
| import com.alibaba.mls.api.ml.MlApiClient | ||||
| import com.alibaba.mls.api.ml.MlRepoInfo | ||||
| import com.alibaba.mnnllm.android.model.ModelUtils | ||||
| import kotlinx.coroutines.Dispatchers | ||||
| import kotlinx.coroutines.launch | ||||
| import com.alibaba.mls.api.download.DownloadCoroutineManager | ||||
| import kotlinx.coroutines.withContext | ||||
| import java.io.File | ||||
| import com.alibaba.mls.api.ml.MlRepoData | ||||
|  | @ -43,15 +42,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|      * Unified method to fetch repo information from Modelers API | ||||
|      * @param modelId the model ID to fetch info for | ||||
|      * @param calculateSize whether to calculate repo size (requires additional network requests) | ||||
|      * @return MlRepoInfo object or null if failed | ||||
|      * @return MlRepoInfo object | ||||
|      * @throws FileDownloadException if failed to fetch repo info | ||||
|      */ | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo? { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MlRepoInfo { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val modelersId = ModelUtils.getRepositoryPath(modelId) | ||||
|                 val split = modelersId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() | ||||
|                 if (split.size != 2) { | ||||
|                     return@runCatching null | ||||
|                     throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo") | ||||
|                 } | ||||
|                  | ||||
|                 val response = mlApiClient.apiService.getModelFiles(split[0], split[1], "").execute() | ||||
|  | @ -84,24 +84,33 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|                     callback?.onRepoInfo(modelId, lastModified, repoSize) | ||||
|                     repoInfo | ||||
|                 } else { | ||||
|                     null | ||||
|                     val errorMsg = if (!response.isSuccessful) { | ||||
|                         "API request failed with code ${response.code()}: ${response.message()}" | ||||
|                     } else { | ||||
|                         "API response was null or empty" | ||||
|                     } | ||||
|                     throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg") | ||||
|                 } | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to fetch repo info for $modelId", exception) | ||||
|                 null | ||||
|                 throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override fun download(modelId: String) { | ||||
|         Log.d(TAG, "start download  ${modelId}") | ||||
|         DownloadExecutor.executeScope.launch { | ||||
|         DownloadCoroutineManager.launchDownload { | ||||
|             downloadRepo(modelId) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override suspend fun checkUpdate(modelId: String) { | ||||
|         fetchRepoInfo(modelId, false) | ||||
|         try { | ||||
|             fetchRepoInfo(modelId, false) | ||||
|         } catch (e: FileDownloadException) { | ||||
|             Log.e(TAG, "Failed to check update for $modelId", e) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override fun getDownloadPath(modelId: String): File { | ||||
|  | @ -125,16 +134,22 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|     } | ||||
| 
 | ||||
|     override suspend fun getRepoSize(modelId: String): Long { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val repoInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|                 if (repoInfo != null) { | ||||
|                     repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size } | ||||
|                 } else { | ||||
|                     0L | ||||
|                 } | ||||
|                 repoInfo.data.tree.filter { it.type != "dir" }.sumOf { it.size } | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to get repo size for $modelId", exception) | ||||
|                 // Try to get file_size from saved market data as fallback | ||||
|                 try { | ||||
|                     val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId) | ||||
|                     if (marketSize > 0) { | ||||
|                         Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize") | ||||
|                         return@withContext marketSize | ||||
|                     } | ||||
|                 } catch (e: Exception) { | ||||
|                     Log.w(TAG, "Failed to get saved market size for $modelId", e) | ||||
|                 } | ||||
|                 0L | ||||
|             } | ||||
|         } | ||||
|  | @ -149,15 +164,16 @@ class MLModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|             return null | ||||
|         } | ||||
|          | ||||
|         val modelInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|         if (modelInfo != null) { | ||||
|         try { | ||||
|             val modelInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|             callback?.onDownloadTaskAdded() | ||||
|             downloadMlRepoInner(modelId, modelersId, modelInfo) | ||||
|             callback?.onDownloadTaskRemoved() | ||||
|         } else { | ||||
|             callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId")) | ||||
|             return modelInfo | ||||
|         } catch (e: FileDownloadException) { | ||||
|             callback?.onDownloadFailed(modelId, e) | ||||
|             return null | ||||
|         } | ||||
|         return modelInfo | ||||
|     } | ||||
| 
 | ||||
|     private fun getAllFiles(owner: String, repo: String, path: String, allFiles: MutableList<FileInfo>) { | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ import android.util.Log | |||
| import com.alibaba.mls.api.ApplicationProvider | ||||
| import com.alibaba.mls.api.FileDownloadException | ||||
| import com.alibaba.mls.api.hf.HfFileMetadata | ||||
| import com.alibaba.mls.api.download.DownloadExecutor.Companion.executor | ||||
| import com.alibaba.mls.api.download.DownloadFileUtils.createSymlink | ||||
| import com.alibaba.mls.api.download.DownloadFileUtils.deleteDirectoryRecursively | ||||
| import com.alibaba.mls.api.download.DownloadFileUtils.getLastFileName | ||||
|  | @ -22,9 +21,10 @@ import com.alibaba.mls.api.ms.MsApiClient | |||
| import com.alibaba.mls.api.ms.MsRepoInfo | ||||
| import com.alibaba.mnnllm.android.chat.model.ChatDataManager | ||||
| import com.alibaba.mnnllm.android.model.ModelUtils | ||||
| import com.alibaba.mnnllm.android.modelmarket.ModelMarketItem | ||||
| import com.alibaba.mnnllm.android.modelmarket.ModelRepository | ||||
| import com.alibaba.mnnllm.android.utils.TimeUtils | ||||
| import kotlinx.coroutines.Dispatchers | ||||
| import kotlinx.coroutines.runBlocking | ||||
| import com.alibaba.mls.api.download.DownloadCoroutineManager | ||||
| import kotlinx.coroutines.withContext | ||||
| import retrofit2.Call | ||||
| import retrofit2.Callback | ||||
|  | @ -44,15 +44,18 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|      * Unified method to fetch repo information from ModelScope API | ||||
|      * @param modelId the model ID to fetch info for | ||||
|      * @param calculateSize whether to calculate repo size (not needed for ModelScope as size is included) | ||||
|      * @return MsRepoInfo object or null if failed | ||||
|      * @return MsRepoInfo object | ||||
|      * @throws FileDownloadException if failed to fetch repo info | ||||
|      */ | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo? { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|     private suspend fun fetchRepoInfo(modelId: String, calculateSize: Boolean = false): MsRepoInfo { | ||||
|         Log.d(TAG, "fetchRepoInfo called for modelId: $modelId, calculateSize: $calculateSize") | ||||
|          | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val msModelId = ModelUtils.getRepositoryPath(modelId) | ||||
|                 val split = msModelId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() | ||||
|                 if (split.size != 2) { | ||||
|                     return@runCatching null | ||||
|                     throw FileDownloadException("Invalid model ID format for $modelId, expected format: owner/repo") | ||||
|                 } | ||||
|                  | ||||
|                 val response = msApiClient.apiService.getModelFiles(split[0], split[1]).execute() | ||||
|  | @ -65,25 +68,34 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|                     callback?.onRepoInfo(modelId, lastModified, repoSize) | ||||
|                     repoInfo | ||||
|                 } else { | ||||
|                     null | ||||
|                     val errorMsg = if (!response.isSuccessful) { | ||||
|                         "API request failed with code ${response.code()}: ${response.message()}" | ||||
|                     } else { | ||||
|                         "API response was null or empty" | ||||
|                     } | ||||
|                     throw FileDownloadException("Failed to fetch repo info for $modelId: $errorMsg") | ||||
|                 } | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to fetch repo info for $modelId", exception) | ||||
|                 null | ||||
|                 throw FileDownloadException("Failed to fetch repo info for $modelId: ${exception.message}") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override fun download(modelId: String) { | ||||
|         executor!!.submit { | ||||
|             kotlinx.coroutines.runBlocking { | ||||
|                 downloadMsRepo(modelId) | ||||
|             } | ||||
|         Log.d(TAG, "MsModelDownloader download: $modelId") | ||||
|          | ||||
|         DownloadCoroutineManager.launchDownload { | ||||
|             downloadMsRepo(modelId) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override suspend fun checkUpdate(modelId: String) { | ||||
|         fetchRepoInfo(modelId) | ||||
|         try { | ||||
|             fetchRepoInfo(modelId) | ||||
|         } catch (e: FileDownloadException) { | ||||
|             Log.e(TAG, "Failed to check update for $modelId", e) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     override fun getDownloadPath(modelId: String): File { | ||||
|  | @ -107,16 +119,22 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
|     } | ||||
| 
 | ||||
|     override suspend fun getRepoSize(modelId: String): Long { | ||||
|         return withContext(Dispatchers.IO) { | ||||
|         return withContext(DownloadCoroutineManager.downloadDispatcher) { | ||||
|             runCatching { | ||||
|                 val repoInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|                 if (repoInfo != null) { | ||||
|                     repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L | ||||
|                 } else { | ||||
|                     0L | ||||
|                 } | ||||
|                 repoInfo.Data?.Files?.filter { it.Type != "tree" }?.sumOf { it.Size } ?: 0L | ||||
|             }.getOrElse { exception -> | ||||
|                 Log.e(TAG, "Failed to get repo size for $modelId", exception) | ||||
|                 // Try to get file_size from saved market data as fallback | ||||
|                 try { | ||||
|                     val marketSize = com.alibaba.mls.api.download.DownloadPersistentData.getMarketSizeTotal(ApplicationProvider.get(), modelId) | ||||
|                     if (marketSize > 0) { | ||||
|                         Log.d(TAG, "Using saved market size as fallback for $modelId: $marketSize") | ||||
|                         return@withContext marketSize | ||||
|                     } | ||||
|                 } catch (e: Exception) { | ||||
|                     Log.w(TAG, "Failed to get saved market size for $modelId", e) | ||||
|                 } | ||||
|                 0L | ||||
|             } | ||||
|         } | ||||
|  | @ -124,20 +142,20 @@ class MsModelDownloader(override var callback: ModelRepoDownloadCallback?, | |||
| 
 | ||||
|     private suspend fun downloadMsRepo(modelId: String) { | ||||
|         val modelScopeId = ModelUtils.getRepositoryPath(modelId) | ||||
|         Log.d(TAG, "MsModelDownloader downloadMsRepo: $modelId modelScopeId : $modelScopeId") | ||||
|         val split = modelScopeId.split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() | ||||
|         if (split.size != 2) { | ||||
|             callback?.onDownloadFailed(modelId, FileDownloadException("getRepoInfoFailed modelId format error: $modelId")) | ||||
|             return | ||||
|         } | ||||
|          | ||||
|         val repoInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|         if (repoInfo != null) { | ||||
|             Log.d(TAG, "downloadMsRepoInner executor") | ||||
|         try { | ||||
|             val repoInfo = fetchRepoInfo(modelId, calculateSize = true) | ||||
|             Log.d(TAG, "downloadMsRepo repoInfo: $repoInfo") | ||||
|             callback?.onDownloadTaskAdded() | ||||
|             downloadMsRepoInner(modelId, modelScopeId, repoInfo) | ||||
|             callback?.onDownloadTaskRemoved() | ||||
|         } else { | ||||
|             callback?.onDownloadFailed(modelId, FileDownloadException("Failed to get repo info for $modelId")) | ||||
|         } catch (e: FileDownloadException) { | ||||
|             callback?.onDownloadFailed(modelId, e) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -69,7 +69,7 @@ class BenchmarkStateMachine { | |||
|     fun isValidTransition(from: BenchmarkState, to: BenchmarkState): Boolean { | ||||
|         val validTransitions = when (from) { | ||||
|             BenchmarkState.IDLE -> listOf(BenchmarkState.LOADING_MODELS) | ||||
|             BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND) | ||||
|             BenchmarkState.LOADING_MODELS -> listOf(BenchmarkState.READY, BenchmarkState.ERROR_MODEL_NOT_FOUND, BenchmarkState.ERROR) | ||||
|             BenchmarkState.READY -> listOf(BenchmarkState.INITIALIZING, BenchmarkState.LOADING_MODELS) | ||||
|             BenchmarkState.INITIALIZING -> listOf(BenchmarkState.RUNNING, BenchmarkState.ERROR) | ||||
|             BenchmarkState.RUNNING -> listOf(BenchmarkState.STOPPING, BenchmarkState.COMPLETED, BenchmarkState.ERROR) | ||||
|  |  | |||
|  | @ -356,6 +356,9 @@ class MarketItemHolder( | |||
|                 R.id.menu_open_model_card -> { | ||||
|                     openModelCard(itemView.context, modelMarketItem) | ||||
|                 } | ||||
|                 R.id.menu_copy_error_info -> { | ||||
|                     copyErrorInfoToClipboard(modelMarketItemWrapper) | ||||
|                 } | ||||
|             } | ||||
|             true | ||||
|         } | ||||
|  | @ -388,6 +391,10 @@ class MarketItemHolder( | |||
|          | ||||
|         // Model card: always visible | ||||
|         menu.findItem(R.id.menu_open_model_card).isVisible = false | ||||
|          | ||||
|         // Copy error info: visible only for failed downloads | ||||
|         menu.findItem(R.id.menu_copy_error_info).isVisible =  | ||||
|             downloadState == DownloadState.FAILED | ||||
|     } | ||||
| 
 | ||||
|     private fun handleSettingsMenu(modelMarketItem: ModelMarketItem) { | ||||
|  | @ -409,6 +416,41 @@ class MarketItemHolder( | |||
|     private fun openModelCard(context: android.content.Context, modelMarketItem: ModelMarketItem) { | ||||
|     } | ||||
| 
 | ||||
|     private fun copyErrorInfoToClipboard(modelMarketItemWrapper: ModelMarketItemWrapper) { | ||||
|         val context = itemView.context | ||||
|         val downloadInfo = modelMarketItemWrapper.downloadInfo | ||||
|         val modelMarketItem = modelMarketItemWrapper.modelMarketItem | ||||
|          | ||||
|         if (downloadInfo.downloadState != DownloadState.FAILED) { | ||||
|             return | ||||
|         } | ||||
|          | ||||
|         val errorMessage = downloadInfo.errorMessage ?: "Unknown error" | ||||
|         val modelInfo = "Model: ${modelMarketItem.modelName} (${modelMarketItem.modelId})" | ||||
|          | ||||
|         // Build error info with stack trace if available | ||||
|         val errorInfoBuilder = StringBuilder().apply { | ||||
|             appendLine(modelInfo) | ||||
|             appendLine("Error: $errorMessage") | ||||
|             appendLine("Time: ${java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", java.util.Locale.getDefault()).format(java.util.Date())}") | ||||
|              | ||||
|             // Add stack trace if exception is available | ||||
|             downloadInfo.errorException?.let { exception -> | ||||
|                 appendLine() | ||||
|                 appendLine("Stack Trace:") | ||||
|                 appendLine(android.util.Log.getStackTraceString(exception)) | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         val errorInfo = errorInfoBuilder.toString() | ||||
|          | ||||
|         val clipboard = context.getSystemService(android.content.Context.CLIPBOARD_SERVICE) as android.content.ClipboardManager | ||||
|         val clip = android.content.ClipData.newPlainText("Error Info", errorInfo) | ||||
|         clipboard.setPrimaryClip(clip) | ||||
|          | ||||
|         Toast.makeText(context, context.getString(R.string.error_info_copied), Toast.LENGTH_SHORT).show() | ||||
|     } | ||||
| 
 | ||||
|     companion object { | ||||
|         const val TAG = "ModelItemHolder" | ||||
|     } | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ data class ModelMarketItem( | |||
|     val categories: List<String>, | ||||
|     val sources: Map<String, String>, | ||||
|     val description: String? = null, | ||||
|     @SerializedName("file_size") val fileSize: Long = 0L, // File size in bytes from model_market.json | ||||
|     var currentSource: String = "", // e.g. "modelscope", "huggingface" | ||||
|     var currentRepoPath: String = "", // e.g. "MNN/Qwen-1.8B-Chat-Int4" | ||||
|     var modelId: String = "" // e.g. "ModelScope/MNN/Qwen-1.8B-Chat-Int4" | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ package com.alibaba.mnnllm.android.modelmarket | |||
| import android.content.Context | ||||
| import android.util.Log | ||||
| import androidx.preference.PreferenceManager | ||||
| import com.alibaba.mls.api.download.DownloadPersistentData | ||||
| import com.alibaba.mnnllm.android.mainsettings.MainSettings | ||||
| import com.google.gson.Gson | ||||
| import kotlinx.coroutines.Dispatchers | ||||
|  | @ -233,6 +234,16 @@ class ModelRepository(private val context: Context) { | |||
|             item.currentSource = selectedSource | ||||
|             item.currentRepoPath = item.sources[selectedSource]!! | ||||
|             item.modelId = "$selectedSource/${item.sources[selectedSource]!!}" | ||||
|              | ||||
|             // Save file_size to DownloadPersistentData if available | ||||
|             if (item.fileSize > 0) { | ||||
|                 try { | ||||
|                     DownloadPersistentData.saveMarketSizeTotalSuspend(context, item.modelId, item.fileSize) | ||||
|                     Log.d(TAG, "Saved market size for ${item.modelId}: ${item.fileSize}") | ||||
|                 } catch (e: Exception) { | ||||
|                     Log.w(TAG, "Failed to save market size for ${item.modelId}", e) | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,4 +19,7 @@ | |||
|     <item | ||||
|         android:id="@+id/menu_open_model_card" | ||||
|         android:title="@string/open_model_card" /> | ||||
|     <item | ||||
|         android:id="@+id/menu_copy_error_info" | ||||
|         android:title="@string/menu_copy_error_info" /> | ||||
| </menu>  | ||||
|  | @ -485,4 +485,6 @@ | |||
|     <string name="allow_network_market_data">允许联网获取模型市场数据</string> | ||||
|     <string name="enable_network_delay">启用网络延时(3秒)</string> | ||||
|     <string name="update_available">有新版本</string> | ||||
|     <string name="menu_copy_error_info">复制错误信息</string> | ||||
|     <string name="error_info_copied">错误信息已复制到剪贴板</string> | ||||
| </resources> | ||||
|  |  | |||
|  | @ -488,4 +488,6 @@ | |||
|     <string name="allow_network_market_data">Allow network to fetch model market data</string> | ||||
|     <string name="enable_network_delay">Enable network delay (3 seconds)</string> | ||||
|     <string name="update_available">Update Available</string> | ||||
|     <string name="menu_copy_error_info">Copy Error Info</string> | ||||
|     <string name="error_info_copied">Error information copied to clipboard</string> | ||||
| </resources> | ||||
|  | @ -0,0 +1,126 @@ | |||
| import json | ||||
| import argparse | ||||
| from huggingface_hub import HfApi | ||||
| from huggingface_hub.utils import HfHubHTTPError | ||||
| 
 | ||||
| def get_repo_size_in_bytes(repo_id: str) -> int: | ||||
|     """ | ||||
|     Calculates the total size of all files in a Hugging Face model repository. | ||||
| 
 | ||||
|     Args: | ||||
|         repo_id (str): The ID of the model repository, e.g., 'google-bert/bert-base-uncased'. | ||||
| 
 | ||||
|     Returns: | ||||
|         int: The total size of the files in bytes. Returns -1 if the repository | ||||
|              cannot be found or an error occurs. | ||||
|     """ | ||||
|     # Initialize the HfApi client | ||||
|     api = HfApi() | ||||
|     total_size = 0 | ||||
|      | ||||
|     try: | ||||
|         # Fetch model information, including metadata for each file | ||||
|         print(f"Fetching metadata for repository: '{repo_id}'...") | ||||
|         model_info = api.model_info(repo_id=repo_id, files_metadata=True) | ||||
|          | ||||
|         # Sum the size of each file in the repository | ||||
|         for file in model_info.siblings: | ||||
|             if file.size is not None: | ||||
|                 total_size += file.size | ||||
|          | ||||
|         print(f"Successfully calculated size for '{repo_id}': {total_size} bytes") | ||||
|         return total_size | ||||
|          | ||||
|     except HfHubHTTPError as e: | ||||
|         # Handle cases where the repository is not found or other HTTP errors | ||||
|         print(f"Error: Could not retrieve info for repository '{repo_id}'. It might not exist or be private.") | ||||
|         print(f"Details: {e}") | ||||
|         return -1 | ||||
|     except Exception as e: | ||||
|         # Handle other potential exceptions | ||||
|         print(f"An unexpected error occurred while processing '{repo_id}': {e}") | ||||
|         return -1 | ||||
| 
 | ||||
| def process_model_list(models: list): | ||||
|     """ | ||||
|     Iterates through a list of model objects and adds the 'file_size' field. | ||||
| 
 | ||||
|     Args: | ||||
|         models (list): A list of dictionaries, where each dictionary represents a model. | ||||
|     """ | ||||
|     if not models: | ||||
|         return # Do nothing if the list is empty | ||||
| 
 | ||||
|     for model in models: | ||||
|         # Check if the model has a HuggingFace source | ||||
|         if 'sources' in model and 'HuggingFace' in model['sources']: | ||||
|             repo_id = model['sources']['HuggingFace'] | ||||
|             if repo_id: | ||||
|                 # Get the repository size and add it to the model object | ||||
|                 file_size = get_repo_size_in_bytes(repo_id) | ||||
|                 model['file_size'] = file_size | ||||
|             else: | ||||
|                 # Handle cases where the HuggingFace repo_id is empty | ||||
|                 model['file_size'] = -1 | ||||
|                 print(f"Warning: Empty HuggingFace repo_id for model '{model.get('modelName', 'N/A')}'.") | ||||
|         else: | ||||
|             # If no HuggingFace source, you can decide what to do. | ||||
|             # Here we skip adding the field, or you could add 'file_size': 0 or -1 | ||||
|             print(f"Skipping model '{model.get('modelName', 'N/A')}' as it has no HuggingFace source.") | ||||
| 
 | ||||
| 
 | ||||
| def main(): | ||||
|     """ | ||||
|     Main function to parse arguments, read the input JSON, process it, | ||||
|     and write to the output JSON file. | ||||
|     """ | ||||
|     # Set up argument parser for command-line interface | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Process a market config JSON file to add Hugging Face model sizes." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "-i", "--input", | ||||
|         required=True, | ||||
|         help="Path to the input JSON file." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "-o", "--output", | ||||
|         required=True, | ||||
|         help="Path to the output JSON file." | ||||
|     ) | ||||
|      | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     # Read the input JSON file | ||||
|     try: | ||||
|         with open(args.input, 'r', encoding='utf-8') as f: | ||||
|             data = json.load(f) | ||||
|     except FileNotFoundError: | ||||
|         print(f"Error: Input file not found at '{args.input}'") | ||||
|         return | ||||
|     except json.JSONDecodeError: | ||||
|         print(f"Error: Could not decode JSON from the input file '{args.input}'.") | ||||
|         return | ||||
| 
 | ||||
|     # Define the keys that contain lists of models to process | ||||
|     model_list_keys = ['models', 'tts_models', 'asr_models'] | ||||
| 
 | ||||
|     # Process each list of models | ||||
|     for key in model_list_keys: | ||||
|         if key in data and isinstance(data[key], list): | ||||
|             print(f"\n--- Processing models in '{key}' ---") | ||||
|             process_model_list(data[key]) | ||||
|         else: | ||||
|             print(f"\n--- No models found in '{key}', skipping ---") | ||||
| 
 | ||||
|     # Write the updated data to the output JSON file | ||||
|     try: | ||||
|         with open(args.output, 'w', encoding='utf-8') as f: | ||||
|             json.dump(data, f, ensure_ascii=False, indent=4) | ||||
|         print(f"\nProcessing complete. Output successfully written to '{args.output}'") | ||||
|     except IOError as e: | ||||
|         print(f"Error: Could not write to output file '{args.output}'. Details: {e}") | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|  | @ -102,7 +102,7 @@ final class LLMChatInteractor: ChatInteractorProtocol { | |||
| //                    PerformanceMonitor.shared.measureExecutionTime(operation: "String concatenation") { | ||||
|                         var updateLastMsg = self?.chatState.value[(self?.chatState.value.count ?? 1) - 1] | ||||
|                          | ||||
|                         if let isDeepSeek = self?.modelInfo.modelName.lowercased().contains("deepseek"), isDeepSeek == true, | ||||
|                         if self?.modelInfo.tags.contains("Think") == true, | ||||
|                             let text = self?.processor.process(progress: message.text) { | ||||
|                             updateLastMsg?.text = text | ||||
|                         } else { | ||||
|  |  | |||
|  | @ -15,7 +15,7 @@ struct ChatHistoryItemView: View { | |||
|              | ||||
|             if let lastMessage = getLastNonEmptyMessage() { | ||||
|                 Text(String(lastMessage.content.prefix(200))) | ||||
|                     .lineLimit(1) | ||||
|                     .lineLimit(3) | ||||
|                     .font(.system(size: 15, weight: .medium)) | ||||
|                     .foregroundColor(.primary) | ||||
|             } | ||||
|  |  | |||
|  | @ -709,9 +709,13 @@ bool remove_directory_safely(const std::string& path) { | |||
|             } | ||||
|              | ||||
|             std::string inputStr = [input UTF8String]; | ||||
|             #ifdef DEBUG | ||||
|             if (inputStr == "benchmark") { | ||||
|                 [blockSelf performBenchmarkWithOutput:&os]; | ||||
|             } else { | ||||
|             #else | ||||
|             { | ||||
|             #endif | ||||
|                 // Get initial context state for performance measurement | ||||
|                 auto context = blockSelf->_llm->getContext(); | ||||
|                 int initial_prompt_len = context->prompt_len; | ||||
|  |  | |||
|  | @ -1308,7 +1308,14 @@ | |||
|       } | ||||
|     }, | ||||
|     "Yes" : { | ||||
| 
 | ||||
|       "localizations" : { | ||||
|         "zh-Hans" : { | ||||
|           "stringUnit" : { | ||||
|             "state" : "translated", | ||||
|             "value" : "是" | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     }, | ||||
|     "搜索本地模型..." : { | ||||
|       "localizations" : { | ||||
|  |  | |||
|  | @ -125,6 +125,7 @@ class BenchmarkService: ObservableObject { | |||
|      | ||||
|     /// Releases the current model and frees associated resources | ||||
|     func releaseModel() { | ||||
|         llmEngine?.cancelInference() | ||||
|         llmEngine = nil | ||||
|         currentModelId = nil | ||||
|     } | ||||
|  |  | |||
|  | @ -34,6 +34,7 @@ class BenchmarkViewModel: ObservableObject { | |||
|      | ||||
|     @Published var startButtonText = String(localized: "Start Test") | ||||
|     @Published var isStartButtonEnabled = true | ||||
|     @Published var showStopConfirmation = false | ||||
|      | ||||
|     // MARK: - Private Properties | ||||
|      | ||||
|  | @ -105,6 +106,7 @@ class BenchmarkViewModel: ObservableObject { | |||
|      | ||||
|     /// Handles benchmark stop confirmation | ||||
|     func onStopBenchmarkTapped() { | ||||
|         showStopConfirmation = false | ||||
|         stopBenchmark() | ||||
|     } | ||||
|      | ||||
|  | @ -119,8 +121,8 @@ class BenchmarkViewModel: ObservableObject { | |||
|         showResults = false | ||||
|         hideStatus() | ||||
|          | ||||
|         // Release model to free memory | ||||
|         benchmarkService.releaseModel() | ||||
|         // Clean up resources when deleting results | ||||
|         cleanupBenchmarkResources() | ||||
|     } | ||||
|      | ||||
|     /// Placeholder for future result submission functionality | ||||
|  | @ -200,14 +202,16 @@ class BenchmarkViewModel: ObservableObject { | |||
|     private func stopBenchmark() { | ||||
|         updateStatus("Stopping benchmark...") | ||||
|         benchmarkService.stopBenchmark() | ||||
|         MemoryMonitor.shared.stop() | ||||
|         cleanupBenchmarkResources() | ||||
|     } | ||||
|      | ||||
|     // MARK: - UI State Management | ||||
|      | ||||
|     /// Updates UI state when benchmark starts | ||||
|     private func onBenchmarkStarted() { | ||||
|         isRunning = true | ||||
|         isStartButtonEnabled = true | ||||
|         startButtonText = String(localized: "Stop Test") | ||||
|         showProgressBar = true | ||||
|         showResults = false | ||||
|         updateStatus("Initializing benchmark...") | ||||
|  | @ -215,11 +219,22 @@ class BenchmarkViewModel: ObservableObject { | |||
|      | ||||
|     /// Resets UI to initial state | ||||
|     private func resetUIState() { | ||||
|         isRunning = false | ||||
|         isStartButtonEnabled = true | ||||
|         startButtonText = String(localized: "Start Test") | ||||
|         showProgressBar = false | ||||
|         hideStatus() | ||||
|         showResults = false | ||||
|         cleanupBenchmarkResources() | ||||
|     } | ||||
|      | ||||
|     /// Cleans up benchmark resources including memory monitoring and model | ||||
|     private func cleanupBenchmarkResources() { | ||||
|         MemoryMonitor.shared.stop() | ||||
|         MemoryMonitor.shared.reset() | ||||
|          | ||||
|         // Release model to free memory | ||||
|         benchmarkService.releaseModel() | ||||
|     } | ||||
|      | ||||
|     /// Updates status message display | ||||
|  | @ -238,9 +253,9 @@ class BenchmarkViewModel: ObservableObject { | |||
|         showError = true | ||||
|     } | ||||
|      | ||||
|     /// Placeholder for stop confirmation alert (handled in View) | ||||
|     /// Shows stop confirmation alert | ||||
|     private func showStopConfirmationAlert() { | ||||
|         // This will be handled in the View with an alert | ||||
|         showStopConfirmation = true | ||||
|     } | ||||
|      | ||||
|     /// Formats progress messages with appropriate status text based on progress type | ||||
|  | @ -309,11 +324,14 @@ extension BenchmarkViewModel: BenchmarkCallback { | |||
|         benchmarkResults = results | ||||
|         showResults = true | ||||
|          | ||||
|         // Only stop memory monitoring if benchmark is no longer running (all tests completed) | ||||
|         if !isRunning { | ||||
|             // Stop memory monitoring | ||||
|             MemoryMonitor.shared.stop() | ||||
|         } | ||||
|         // Update UI state to reflect completion | ||||
|         isRunning = false | ||||
|         isStartButtonEnabled = true | ||||
|         startButtonText = String(localized: "Start Test") | ||||
|         showProgressBar = false | ||||
|          | ||||
|         // Clean up resources after benchmark completion | ||||
|         cleanupBenchmarkResources() | ||||
|          | ||||
|         // Always hide status after processing results | ||||
|         hideStatus() | ||||
|  | @ -324,9 +342,16 @@ extension BenchmarkViewModel: BenchmarkCallback { | |||
|     /// Handles benchmark errors with user-friendly error messages | ||||
|     func onBenchmarkError(_ errorCode: Int, _ message: String) { | ||||
|         let errorCodeName = BenchmarkErrorCode(rawValue: errorCode)?.description ?? "Unknown" | ||||
|         showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)") | ||||
|          | ||||
|         // Check if this is a user-initiated stop - don't show error dialog | ||||
|         if errorCode == BenchmarkErrorCode.benchmarkStopped.rawValue { | ||||
|             print("BenchmarkViewModel: Benchmark stopped by user (\(errorCode)): \(message)") | ||||
|         } else { | ||||
|             showErrorMessage("Benchmark failed (\(errorCodeName)): \(message)") | ||||
|             print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)") | ||||
|         } | ||||
|          | ||||
|         resetUIState() | ||||
|         print("BenchmarkViewModel: Benchmark error (\(errorCode)): \(message)") | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,7 +13,6 @@ import SwiftUI | |||
|  */ | ||||
| struct ModelSelectionCard: View { | ||||
|     @ObservedObject var viewModel: BenchmarkViewModel | ||||
|     @Binding var showStopConfirmation: Bool | ||||
|      | ||||
|     var body: some View { | ||||
|         VStack(alignment: .leading, spacing: 16) { | ||||
|  | @ -140,11 +139,7 @@ struct ModelSelectionCard: View { | |||
|      | ||||
|     private var startStopButton: some View { | ||||
|         Button(action: { | ||||
|             if viewModel.startButtonText.contains("Stop") { | ||||
|                 showStopConfirmation = true | ||||
|             } else { | ||||
|                 viewModel.onStartBenchmarkTapped() | ||||
|             } | ||||
|             viewModel.onStartBenchmarkTapped() | ||||
|         }) { | ||||
|             HStack(spacing: 12) { | ||||
|                 ZStack { | ||||
|  | @ -152,12 +147,12 @@ struct ModelSelectionCard: View { | |||
|                         .fill(Color.white.opacity(0.2)) | ||||
|                         .frame(width: 32, height: 32) | ||||
|                      | ||||
|                     if viewModel.isRunning && viewModel.startButtonText.contains("Stop") { | ||||
|                     if viewModel.isRunning { | ||||
|                         ProgressView() | ||||
|                             .progressViewStyle(CircularProgressViewStyle(tint: .white)) | ||||
|                             .scaleEffect(0.7) | ||||
|                     } else { | ||||
|                         Image(systemName: viewModel.startButtonText.contains("Stop") ? "stop.fill" : "play.fill") | ||||
|                         Image(systemName: viewModel.isRunning ? "stop.fill" : "play.fill") | ||||
|                             .font(.system(size: 16, weight: .bold)) | ||||
|                             .foregroundColor(.white) | ||||
|                     } | ||||
|  | @ -169,7 +164,7 @@ struct ModelSelectionCard: View { | |||
|                  | ||||
|                 Spacer() | ||||
|                  | ||||
|                 if !viewModel.startButtonText.contains("Stop") { | ||||
|                 if !viewModel.isRunning { | ||||
|                     Image(systemName: "arrow.right") | ||||
|                         .font(.system(size: 16, weight: .semibold)) | ||||
|                         .foregroundColor(.white.opacity(0.8)) | ||||
|  | @ -182,7 +177,7 @@ struct ModelSelectionCard: View { | |||
|                 RoundedRectangle(cornerRadius: 16) | ||||
|                     .fill( | ||||
|                         viewModel.isStartButtonEnabled ?  | ||||
|                         (viewModel.startButtonText.contains("Stop") ?  | ||||
|                         (viewModel.isRunning ?  | ||||
|                          LinearGradient( | ||||
|                              colors: [Color.benchmarkError, Color.benchmarkError.opacity(0.8)], | ||||
|                              startPoint: .leading, | ||||
|  | @ -234,8 +229,7 @@ struct ModelSelectionCard: View { | |||
| 
 | ||||
| #Preview { | ||||
|     ModelSelectionCard( | ||||
|         viewModel: BenchmarkViewModel(), | ||||
|         showStopConfirmation: .constant(false) | ||||
|         viewModel: BenchmarkViewModel() | ||||
|     ) | ||||
|     .padding() | ||||
| } | ||||
|  |  | |||
|  | @ -13,7 +13,6 @@ import SwiftUI | |||
|  */ | ||||
| struct BenchmarkView: View { | ||||
|     @StateObject private var viewModel = BenchmarkViewModel() | ||||
|     @State private var showStopConfirmation = false | ||||
|      | ||||
|     var body: some View { | ||||
|         ZStack { | ||||
|  | @ -21,8 +20,7 @@ struct BenchmarkView: View { | |||
|                 VStack(spacing: 24) { | ||||
|                     // Model Selection Section | ||||
|                     ModelSelectionCard( | ||||
|                         viewModel: viewModel, | ||||
|                         showStopConfirmation: $showStopConfirmation | ||||
|                         viewModel: viewModel | ||||
|                     ) | ||||
|                      | ||||
|                     // Progress Section | ||||
|  | @ -55,7 +53,7 @@ struct BenchmarkView: View { | |||
|                 .padding(.vertical, 16) | ||||
|             } | ||||
|         } | ||||
|         .alert("Stop Benchmark", isPresented: $showStopConfirmation) { | ||||
|         .alert("Stop Benchmark", isPresented: $viewModel.showStopConfirmation) { | ||||
|             Button("Yes", role: .destructive) { | ||||
|                 viewModel.onStopBenchmarkTapped() | ||||
|             } | ||||
|  | @ -68,11 +66,7 @@ struct BenchmarkView: View { | |||
|         } message: { | ||||
|             Text(viewModel.errorMessage) | ||||
|         } | ||||
|         .onReceive(viewModel.$isRunning) { isRunning in | ||||
|             if isRunning && viewModel.startButtonText.contains("Stop") { | ||||
|                 showStopConfirmation = false | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -32,7 +32,9 @@ struct LocalModelListView: View { | |||
|                     viewModel.selectModel(model) | ||||
|                 }) { | ||||
|                     LocalModelRowView(model: model) | ||||
|                         .contentShape(Rectangle()) | ||||
|                 } | ||||
|                 .buttonStyle(PlainButtonStyle()) | ||||
|                 .listRowBackground(viewModel.pinnedModelIds.contains(model.id) ? Color.black.opacity(0.05) : Color.clear) | ||||
|                 .swipeActions(edge: .trailing, allowsFullSwipe: false) { | ||||
|                     SwipeActionsView(model: model, viewModel: viewModel) | ||||
|  |  | |||
|  | @ -59,6 +59,10 @@ struct LocalModelRowView: View { | |||
|                     } | ||||
|                 } | ||||
|             } | ||||
|              | ||||
|             Spacer() | ||||
|         } | ||||
|         .padding(.vertical, 8) | ||||
|         .frame(maxWidth: .infinity, alignment: .leading) | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -142,7 +142,10 @@ struct MainTabView: View { | |||
|                         .edgesIgnoringSafeArea(.all) | ||||
|         } | ||||
|         .onChange(of: showHistory) { oldValue, newValue in | ||||
|             if !newValue { | ||||
|             if newValue { | ||||
|                 // Refresh chat history when opening the side menu | ||||
|                 histories = ChatHistoryManager.shared.getAllHistory() | ||||
|             } else { | ||||
|                 DispatchQueue.main.asyncAfter(deadline: .now() + 0.3) { | ||||
|                     withAnimation { | ||||
|                         showHistoryButton = true | ||||
|  | @ -191,6 +194,9 @@ struct MainTabView: View { | |||
|                         modelListViewModel.recordModelUsage(modelName: model.modelName) | ||||
|                     } | ||||
|                      | ||||
|                     // Refresh chat history when returning from chat | ||||
|                     histories = ChatHistoryManager.shared.getAllHistory() | ||||
|                      | ||||
|                     // Clear selections | ||||
|                     modelListViewModel.selectedModel = nil | ||||
|                     selectedHistory = nil | ||||
|  |  | |||
|  | @ -67,7 +67,14 @@ struct ModelInfo: Codable { | |||
|         } | ||||
|          | ||||
|         let sourceKey = ModelSourceManager.shared.selectedSource.rawValue | ||||
|         return sources[sourceKey] ?? "taobao-mnn/\(modelName)" | ||||
|         let baseId = sources[sourceKey] ?? "taobao-mnn/\(modelName)" | ||||
|          | ||||
|         // Add vendor prefix to ensure uniqueness for local models | ||||
|         if vendor == "Local" { | ||||
|             return "local/\(modelName)" | ||||
|         } | ||||
|          | ||||
|         return baseId | ||||
|     } | ||||
|      | ||||
|     var localizedTags: [String] { | ||||
|  |  | |||
|  | @ -72,7 +72,7 @@ class ModelListViewModel: ObservableObject { | |||
|             if !foundModelFiles.isEmpty { | ||||
|                 // Check if we have a complete model (at least config.json) | ||||
|                 if foundModelFiles.contains("config.json") { | ||||
|                     let modelName = "Qwen3-0.6B-MNN" | ||||
|                     let modelName = "Qwen3-0.6B-MNN-Inside" | ||||
|                     let localModel = ModelInfo( | ||||
|                         modelName: modelName, | ||||
|                         tags: [NSLocalizedString("tag.deepThinking", comment: "Deep thinking tag for local model"), | ||||
|  | @ -143,8 +143,10 @@ class ModelListViewModel: ObservableObject { | |||
|                         let itemConfigPath = (itemPath as NSString).appendingPathComponent("config.json") | ||||
|                          | ||||
|                         if fileManager.fileExists(atPath: itemConfigPath) { | ||||
|                             // Use custom name for Qwen3-0.6B-MNN to avoid conflicts | ||||
|                             let modelName = item == "Qwen3-0.6B-MNN" ? "Qwen3-0.6B-MNN-Inside" : item | ||||
|                             let localModel = ModelInfo( | ||||
|                                 modelName: item, | ||||
|                                 modelName: modelName, | ||||
|                                 tags: ["local", "bundled"], | ||||
|                                 categories: ["Local Models"], | ||||
|                                 vendor: "Local", | ||||
|  | @ -153,7 +155,7 @@ class ModelListViewModel: ObservableObject { | |||
|                             ) | ||||
|                             localModels.append(localModel) | ||||
|                              | ||||
|                             ModelStorageManager.shared.markModelAsDownloaded(item) | ||||
|                             ModelStorageManager.shared.markModelAsDownloaded(modelName) | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | @ -175,12 +177,15 @@ class ModelListViewModel: ObservableObject { | |||
|              | ||||
|             var fetchedModels = info.models | ||||
|              | ||||
|             // Add LocalModel folder models | ||||
|             // Add LocalModel folder models, avoiding duplicates | ||||
|             let localModels = await loadLocalModels() | ||||
|             fetchedModels.append(contentsOf: localModels) | ||||
|             let existingModelNames = Set(fetchedModels.map { $0.modelName }) | ||||
|             let uniqueLocalModels = localModels.filter { !existingModelNames.contains($0.modelName) } | ||||
|             fetchedModels.append(contentsOf: uniqueLocalModels) | ||||
|              | ||||
|             filterDiffusionModels(fetchedModels: &fetchedModels) | ||||
|             loadCachedSizes(for: &fetchedModels) | ||||
|             syncDownloadStatus(for: &fetchedModels) | ||||
|             sortModels(fetchedModels: &fetchedModels) | ||||
|             self.models = fetchedModels | ||||
|              | ||||
|  | @ -203,6 +208,18 @@ class ModelListViewModel: ObservableObject { | |||
|         } | ||||
|     } | ||||
|      | ||||
|     private func syncDownloadStatus(for models: inout [ModelInfo]) { | ||||
|         for i in 0..<models.count { | ||||
|             let isDownloaded = ModelStorageManager.shared.isModelDownloaded(models[i].modelName) | ||||
|             models[i].isDownloaded = isDownloaded | ||||
|              | ||||
|             // Also sync last used date | ||||
|             if let lastUsed = ModelStorageManager.shared.getLastUsed(for: models[i].modelName) { | ||||
|                 models[i].lastUsedAt = lastUsed | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     private func fetchModelSizes(for models: [ModelInfo]) async { | ||||
|         await withTaskGroup(of: Void.self) { group in | ||||
|             for (_, model) in models.enumerated() { | ||||
|  |  | |||
|  | @ -117,7 +117,7 @@ static inline uint64_t getTimeInUs() { | |||
| } | ||||
| 
 | ||||
| std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward = MNN_FORWARD_CPU, bool only_inference = true, | ||||
|                            int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) { | ||||
|                            int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false, bool enableKleidiAI=false) { | ||||
|     auto revertor = std::unique_ptr<Revert>(new Revert(model.model_file.c_str())); | ||||
|     if (testQuantModel) { | ||||
|         revertor->initialize(0, sparseBlockOC, false, true); | ||||
|  | @ -130,6 +130,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward | |||
|     auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy); | ||||
|     revertor.reset(); | ||||
|     net->setSessionMode(MNN::Interpreter::Session_Release); | ||||
|     net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, enableKleidiAI); | ||||
|     MNN::ScheduleConfig config; | ||||
|     config.numThread = numberThread; | ||||
|     config.type      = static_cast<MNNForwardType>(forward); | ||||
|  | @ -392,8 +393,9 @@ int main(int argc, const char* argv[]) { | |||
|     int precision = 2; | ||||
|     float sparsity = 0.0f; | ||||
|     int sparseBlockOC = 1; | ||||
|     bool enableKleidiAI = false; | ||||
|     if (argc <= 2) { | ||||
|         std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel]" << std::endl; | ||||
|         std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel] [enableKleidiAI]" << std::endl; | ||||
|         return 1; | ||||
|     } | ||||
|     if (argc >= 3) { | ||||
|  | @ -420,8 +422,11 @@ int main(int argc, const char* argv[]) { | |||
|     if(argc >= 10) { | ||||
|         testQuantizedModel = atoi(argv[9]); | ||||
|     } | ||||
|     if (argc >= 11) { | ||||
|         enableKleidiAI = atoi(argv[10]) > 0 ? true : false; | ||||
|     } | ||||
| 
 | ||||
|     std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << std::endl; | ||||
|     std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << " enableKleidiAI=" << enableKleidiAI << std::endl; | ||||
|     std::vector<Model> models = findModelFiles(argv[1]); | ||||
| 
 | ||||
|     std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl; | ||||
|  | @ -441,10 +446,10 @@ int main(int argc, const char* argv[]) { | |||
|     } | ||||
| 
 | ||||
|     for (auto& m : models) { | ||||
|         std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false); | ||||
|         std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false, enableKleidiAI); | ||||
|         displayStats(m.name.c_str(), costs, false); | ||||
|         if (testQuantizedModel) { | ||||
|             costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true); | ||||
|             costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true, enableKleidiAI); | ||||
|             displayStats(m.name, costs, 1); | ||||
|         } | ||||
|     } | ||||
|  |  | |||
|  | @ -98,5 +98,5 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | |||
| | MNN_SUPPORT_TRANSFORMER_FUSE | 是否支持Fuse Transformer相关OP实现,默认为 `OFF` | | ||||
| | MNN_BUILD_LLM        | 是否构建基于MNN的llm库和demo,默认为`OFF` | | ||||
| | MNN_BUILD_DIFFUSION  | 是否构建基于MNN的diffusion demo,需要打开MNN_BUILD_OPENCV和MNN_IMGCODECS宏使用 默认为`OFF` | | ||||
| | MNN_KLEIDIAI         | 是否集成ARM的klediAI加速库【目前处于实验状态,只能跑对称量化的LLM模型】,默认为`OFF` | | ||||
| | MNN_KLEIDIAI         | 是否集成ARM的klediAI加速库,默认为`ON` | | ||||
| | MNN_USE_RVV          | 是否启用RISC-V向量扩展支持,默认为`OFF` | | ||||
|  |  | |||
|  | @ -252,7 +252,10 @@ public: | |||
|         CPU_CORE_IDS = 14, | ||||
| 
 | ||||
|         // set CPU threads to use when supports Arm sme2
 | ||||
|         CPU_SME2_INSTRUCTIONS = 15 | ||||
|         CPU_SME2_INSTRUCTIONS = 15, | ||||
| 
 | ||||
|         // Enable KleidiAI
 | ||||
|         CPU_ENABLE_KLEIDIAI = 16 | ||||
|     }; | ||||
| 
 | ||||
|     enum ExternalPathType { | ||||
|  |  | |||
|  | @ -21,10 +21,6 @@ | |||
| #include "ThreadPool.hpp" | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "arm/mnn_kleidiai.h" | ||||
| #endif | ||||
| 
 | ||||
| namespace MNN { | ||||
| class WorkerThread; | ||||
| class CPURuntime : public Runtime { | ||||
|  |  | |||
|  | @ -19,8 +19,7 @@ if (MNN_CPU_WEIGHT_DEQUANT_GEMM) | |||
|     FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/normal_memory/*.[sS]) | ||||
| endif() | ||||
| 
 | ||||
| if (MNN_KLEIDIAI) | ||||
|     add_definitions(-DMNN_KLEIDIAI_ENABLED=1) | ||||
| if (MNN_KLEIDIAI AND CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64") | ||||
|     # Disable the KleidiAI tests | ||||
|     set(KLEIDIAI_BUILD_TESTS  OFF) | ||||
|     # Fetch KleidiAI sources: | ||||
|  | @ -121,6 +120,9 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?") | |||
|     endif() | ||||
| elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64") | ||||
|     message(STATUS "Enabling AArch64 Assemblies") | ||||
|     if (MNN_KLEIDIAI) | ||||
|        add_definitions(-DMNN_KLEIDIAI_ENABLED=1) | ||||
|     endif() | ||||
|     if (MNN_SME2) | ||||
|         add_definitions(-DMNN_SME2) | ||||
|         FILE(GLOB MNN_SME2_AArch64_SRC ${MNN_SME2_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/sme2_asm/*.[sS]) | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ | |||
| #include "core/OpCommonUtils.hpp" | ||||
| #include "backend/cpu/OneDNNConvolution.hpp" | ||||
| #include "backend/cpu/compute/ConvInt8TiledExecutor.hpp" | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "backend/cpu/compute/KleidiAIConvInt8.hpp" | ||||
| #include "backend/cpu/compute/KleidiAIConvolution.hpp" | ||||
|  | @ -31,6 +32,92 @@ | |||
| 
 | ||||
| namespace MNN { | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| static Execution* _createKleidiAIUnit(const Tensor* input, const Tensor* output, Backend* backend, const Op* op, | ||||
|                                       const float* originWeight, size_t originWeightSize, const float* bias, | ||||
|                                       size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo, | ||||
|                                       bool supportSparse, bool lowMemory) { | ||||
|     auto cpuBackend = (CPUBackend*)backend; | ||||
|     auto conv2d     = op->main_as_Convolution2D(); | ||||
|     auto common     = conv2d->common(); | ||||
| 
 | ||||
|     bool fastWay = common->kernelY() == 1 && common->kernelX() == 1 && output->width() == input->width() && | ||||
|                    output->height() == input->height() && common->strideX() == 1 && common->strideY() == 1; | ||||
| 
 | ||||
| #ifdef MNN_LOW_MEMORY | ||||
|     if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) { | ||||
|         if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) { | ||||
|             do { | ||||
|                 if (!weightQuantInfo->canUseInt4) { | ||||
|                     break; | ||||
|                 } | ||||
|                 auto convOp = op->main_as_Convolution2D(); | ||||
|                 auto core   = static_cast<CPUBackend*>(backend)->functions(); | ||||
|                 int oc      = convOp->common()->outputCount(); | ||||
|                 int ic      = convOp->common()->inputCount(); | ||||
| 
 | ||||
|                 int blockNum   = 1; | ||||
|                 int dequantCnt = weightQuantInfo->alphaSize; | ||||
|                 if (weightQuantInfo->asymmetric) { | ||||
|                     dequantCnt /= 2; | ||||
|                 } | ||||
|                 blockNum = dequantCnt / oc; | ||||
| 
 | ||||
|                 bool bAsym     = weightQuantInfo->asymmetric; | ||||
|                 size_t blkSize = blockNum == 1 ? 0 : ic / blockNum; | ||||
| 
 | ||||
|                 KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes); | ||||
| 
 | ||||
|                 KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo()); | ||||
|                 if (!kai.canAccelerate(accelType, convOp->common())) { | ||||
|                     break; | ||||
|                 } | ||||
| 
 | ||||
|                 if (!kai.isLoaded(accelType)) { | ||||
|                     kai.setLoaded(accelType); | ||||
|                     kai.printInfo(accelType); | ||||
|                 } | ||||
| 
 | ||||
|                 return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum); | ||||
|             } while (0); | ||||
|         } | ||||
| 
 | ||||
|         // Have not supported the quantized weight.
 | ||||
|         return nullptr; | ||||
|     } | ||||
| #else | ||||
|     if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) { | ||||
|         if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) { | ||||
|             return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, | ||||
|                                                 weightQuantInfo); | ||||
|         } | ||||
| 
 | ||||
|         // Do nothing and fallback.
 | ||||
|         return nullptr; | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
|     // This is different with original impl. It's a corresponding impl for strassen,
 | ||||
|     // which is called when built without MNN_REDUCE_SIZE. But for KleidiAI,
 | ||||
|     // need not to care about this.
 | ||||
|     if (fastWay && cpuBackend->functions()->matmulBytes == 0) { | ||||
|         auto bytes     = cpuBackend->functions()->bytes; | ||||
|         auto accelType = (bytes == 2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32; | ||||
|         KleidiAI& kai  = KleidiAI::getInstance(*MNNGetCPUInfo()); | ||||
|         if (kai.canAccelerate(accelType)) { | ||||
|             return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) { | ||||
|         return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, | ||||
|                                             weightQuantInfo); | ||||
|     } | ||||
| 
 | ||||
|     return nullptr; | ||||
| } | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| 
 | ||||
| static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend* backend, | ||||
|                               const Op* op, const float* originWeight, size_t originWeightSize, const float* bias, size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo, bool supportSparse, bool lowMemory) { | ||||
|     auto cpuBackend = (CPUBackend*)backend; | ||||
|  | @ -48,47 +135,24 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend | |||
|         } | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
|     if (cpuBackend->getRuntime()->hint().enableKleidiAI) { | ||||
|         auto execution = _createKleidiAIUnit(input, output, backend, op, originWeight, originWeightSize, bias, biasSize, | ||||
|                                              weightQuantInfo, supportSparse, lowMemory); | ||||
| 
 | ||||
|         if (execution) { | ||||
|             return execution; | ||||
|         } | ||||
|     } | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| 
 | ||||
|     bool fastWay = common->kernelY() == 1 && common->kernelX() == 1 | ||||
|         && output->width() == input->width() && output->height() == input->height() | ||||
|         && common->strideX() == 1 && common->strideY() == 1; | ||||
| #ifdef MNN_LOW_MEMORY | ||||
|     if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) { | ||||
|         if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) { | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
|             do { | ||||
|                 if (!weightQuantInfo->canUseInt4) { | ||||
|                     break; | ||||
|                 } | ||||
|                 auto convOp = op->main_as_Convolution2D(); | ||||
|                 auto core = static_cast<CPUBackend*>(backend)->functions(); | ||||
|                 int oc = convOp->common()->outputCount(); | ||||
|                 int ic = convOp->common()->inputCount(); | ||||
| 
 | ||||
|                 int blockNum = 1; | ||||
|                 int dequantCnt = weightQuantInfo->alphaSize; | ||||
|                 if (weightQuantInfo->asymmetric) { | ||||
|                     dequantCnt /= 2; | ||||
|                 } | ||||
|                 blockNum = dequantCnt / oc; | ||||
| 
 | ||||
|                 bool bAsym = weightQuantInfo->asymmetric; | ||||
|                 size_t blkSize = blockNum == 1 ? 0 : ic / blockNum; | ||||
| 
 | ||||
|                 KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes); | ||||
| 
 | ||||
|                 KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo()); | ||||
|                 if(!kai.isLoaded(accelType)) { | ||||
|                     kai.setLoaded(accelType); | ||||
|                     kai.printInfo(accelType); | ||||
|                 } | ||||
| 
 | ||||
|                 if(!kai.canAccelerate(accelType, convOp->common())){ | ||||
|                     break; | ||||
|                 } | ||||
|                 return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum); | ||||
|             } while (0); | ||||
| #endif | ||||
| 
 | ||||
|             return new DenseConvInt8TiledExecutor(backend, op, weightQuantInfo, true); | ||||
|         } else { | ||||
|             return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); | ||||
|  | @ -96,37 +160,16 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend | |||
|     } | ||||
| #else | ||||
|     if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) { | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| 	if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) { | ||||
| 	    return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); | ||||
| 	} | ||||
| #endif | ||||
| 
 | ||||
|         return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
| #ifndef MNN_REDUCE_SIZE | ||||
|     if (fastWay && cpuBackend->functions()->matmulBytes == 0) { | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
|         auto bytes = cpuBackend->functions()->bytes;  | ||||
|         auto accelType = (bytes==2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32; | ||||
|         KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo()); | ||||
|         if (kai.canAccelerate(accelType)){ | ||||
|             return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize); | ||||
|         } | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| 
 | ||||
|         return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize); | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
|     if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) { | ||||
| 	return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
|     if (cpuBackend->getRuntime()->hint().winogradMemoryUsed == 0 || (!ConvolutionWinogradBridge::canUseWinograd(common))) { | ||||
|         return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, nullptr); | ||||
|     } | ||||
|  |  | |||
|  | @ -303,4 +303,4 @@ ErrorCode KleidiAIConvInt8::onExecute(const std::vector<Tensor*>& inputs, const | |||
| } | ||||
| 
 | ||||
| } // namespace MNN
 | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
|  |  | |||
|  | @ -6,10 +6,8 @@ | |||
| 
 | ||||
| #ifndef KleidiAIConvInt8_hpp | ||||
| #define KleidiAIConvInt8_hpp | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "backend/cpu/CPUConvolution.hpp" | ||||
| #include "Int8FunctionsOpt.h" | ||||
| #include "CommonOptFunction.h" | ||||
| #include "backend/cpu/arm/mnn_kleidiai.h" | ||||
| 
 | ||||
| namespace MNN { | ||||
| class KleidiAIConvInt8 : public CPUConvolution { | ||||
|  | @ -31,5 +29,4 @@ private: | |||
| }; | ||||
| 
 | ||||
| } // namespace MNN
 | ||||
| #endif // MNN_KLEIDIAI_ENABLED
 | ||||
| #endif /* KleidiAIConvInt8_hpp */ | ||||
| #endif /* KleidiAIConvInt8_hpp */ | ||||
|  |  | |||
|  | @ -3,19 +3,15 @@ | |||
| //
 | ||||
| // SPDX-License-Identifier: Apache-2.0
 | ||||
| //
 | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "KleidiAIConvolution.hpp" | ||||
| #include <string.h> | ||||
| #include "core/BufferAllocator.hpp" | ||||
| #include "backend/cpu/CPUBackend.hpp" | ||||
| #include "core/Concurrency.h" | ||||
| #include "core/TensorUtils.hpp" | ||||
| #include "backend/cpu/CPUTensorConvert.hpp" | ||||
| 
 | ||||
| namespace MNN { | ||||
| #ifndef MNN_REDUCE_SIZE | ||||
| 
 | ||||
| KleidiAIConvolution::KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight, | ||||
|                                         size_t originWeightSize, const float *bias, size_t biasSize) | ||||
|     : CPUConvolution(common, b) { | ||||
|  | @ -228,7 +224,5 @@ ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, co | |||
| 
 | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| } // namespace MNN
 | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
|  |  | |||
|  | @ -6,12 +6,9 @@ | |||
| 
 | ||||
| #ifndef KleidiAIConvolution_hpp | ||||
| #define KleidiAIConvolution_hpp | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include <functional> | ||||
| #include "backend/cpu/CPUConvolution.hpp" | ||||
| #include "backend/cpu/arm/mnn_kleidiai.h" | ||||
| namespace MNN { | ||||
| #ifndef MNN_REDUCE_SIZE | ||||
| 
 | ||||
| class KleidiAIConvolution : public CPUConvolution{ | ||||
|     public: | ||||
|         KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight, size_t originWeightSize, const float *bias, size_t biasSize); | ||||
|  | @ -30,8 +27,5 @@ class KleidiAIConvolution : public CPUConvolution{ | |||
|         KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER; | ||||
|         std::vector<float> mPostParameters; | ||||
| }; | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
| 
 | ||||
| } // namespace MNN
 | ||||
| #endif | ||||
| #endif /* KleidiAIConvolution_hpp */ | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| #if MNN_KLEIDIAI_ENABLED | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "KleidiAIDenseConvolution.hpp" | ||||
| 
 | ||||
| #include <numeric> | ||||
|  | @ -9,6 +9,7 @@ | |||
| #include "backend/cpu/CPUTensorConvert.hpp" | ||||
| #include "core/Macro.h" | ||||
| #include "core/TensorUtils.hpp" | ||||
| #include "core/Concurrency.h" | ||||
| #include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" | ||||
| #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" | ||||
| #include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" | ||||
|  | @ -304,10 +305,15 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp | |||
|         .dilatedWidth  = mCommon->dilateX(), | ||||
|     }; | ||||
| 
 | ||||
|     mFunction.second = [=](int tid) { | ||||
|     int threadNum = static_cast<CPUBackend*>(backend())->threadNumber(); | ||||
|     mFunction.second = [=](int tId) { | ||||
|         // Convert NC4HW4 to NHWC
 | ||||
|         auto inputShape = input->shape(); // TODO check for NC4HW4, should be the NCHW
 | ||||
|         CPUTensorConverter::convert(input, &mInputNHWC, core); | ||||
|         // CPUTensorConverter::convert(input, &mInputNHWC, core);
 | ||||
|         MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|             CPUTensorConverter::convert(input, &mInputNHWC, core, tId, threadNum); | ||||
|         }; | ||||
|         MNN_CONCURRENCY_END(); | ||||
|         // Lhs packing
 | ||||
|         if (bytes == 4) { | ||||
|             int blockSize = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme(); | ||||
|  | @ -348,7 +354,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp | |||
|         } | ||||
| 
 | ||||
|         // Convert NHWC to NC4HW4
 | ||||
|         CPUTensorConverter::convert(&mOutputNHWC, output, core); | ||||
|         // CPUTensorConverter::convert(&mOutputNHWC, output, core);
 | ||||
|         MNN_CONCURRENCY_BEGIN(tId, threadNum) { | ||||
|             CPUTensorConverter::convert(&mOutputNHWC, output, core, tId, threadNum); | ||||
|         }; | ||||
|         MNN_CONCURRENCY_END(); | ||||
|     }; | ||||
|     return NO_ERROR; | ||||
| } | ||||
|  | @ -359,4 +369,4 @@ ErrorCode KleidiAIDenseConvolutionImpl::onExecute(const std::vector<Tensor*>& in | |||
|     return NO_ERROR; | ||||
| } | ||||
| } // namespace MNN
 | ||||
| #endif | ||||
| #endif //MNN_KLEIDIAI_ENABLED
 | ||||
|  |  | |||
|  | @ -1,5 +1,3 @@ | |||
| #if MNN_KLEIDIAI_ENABLED | ||||
| 
 | ||||
| #ifndef KleidiAIDenseConvolution_hpp | ||||
| #define KleidiAIDenseConvolution_hpp | ||||
| 
 | ||||
|  | @ -242,4 +240,3 @@ private: | |||
| } // namespace MNN
 | ||||
| 
 | ||||
| #endif /* KleidiAIDenseConvolution_hpp */ | ||||
| #endif | ||||
|  |  | |||
|  | @ -62,6 +62,8 @@ struct RuntimeHint { | |||
|     // whether to use Arm sme2 cores when threads>1
 | ||||
|     bool useArmSme2Cores = true; | ||||
| 
 | ||||
|     bool enableKleidiAI = false; | ||||
| 
 | ||||
|     // Use CPU Ids
 | ||||
|     std::vector<int> cpuIds; | ||||
| }; | ||||
|  |  | |||
|  | @ -109,6 +109,9 @@ void Session::ModeGroup::setHint(Interpreter::HintMode hint, int value) { | |||
|         case Interpreter::HintMode::INIT_THREAD_NUMBER: | ||||
|             runtimeHint.initThreadNumber = value; | ||||
|             break; | ||||
|         case Interpreter::HintMode::CPU_ENABLE_KLEIDIAI: | ||||
|             runtimeHint.enableKleidiAI = value > 0 ? true : false; | ||||
|             break; | ||||
|         default: | ||||
|             break; | ||||
|     } | ||||
|  |  | |||
|  | @ -19,10 +19,6 @@ | |||
| #undef CONSTANT | ||||
| #endif // CONSTANT
 | ||||
| 
 | ||||
| #ifdef MNN_KLEIDIAI_ENABLED | ||||
| #include "../backend/cpu/arm/mnn_kleidiai.h" | ||||
| #endif | ||||
| 
 | ||||
| namespace MNN { | ||||
| struct TensorArrayAttr { | ||||
|     // array size is dynamic or not
 | ||||
|  |  | |||
|  | @ -92,3 +92,7 @@ MNNForwardType getCurrentType() { | |||
|     return attr->firstType; | ||||
| } | ||||
| 
 | ||||
| std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor() { | ||||
|     auto attr = MNN::Express::ExecutorScope::Current()->getAttr(); | ||||
|     return MNN::Express::Executor::newExecutor(getCurrentType(), attr->config, attr->numThread); | ||||
| } | ||||
|  |  | |||
|  | @ -104,6 +104,8 @@ inline float keepFP32Precision(float fp32Value) { | |||
| } | ||||
| MNNForwardType getCurrentType(); | ||||
| 
 | ||||
| std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor(); | ||||
| 
 | ||||
| using ConvertFP32 = float(*)(float fp32Value); | ||||
| 
 | ||||
| const static std::vector<ConvertFP32> FP32Converter = { | ||||
|  |  | |||
|  | @ -49,6 +49,12 @@ public: | |||
|         for (int i = 0; i < channel * height * width; ++i){ | ||||
|             inputData[i] = (rand() % 10) * 0.1; | ||||
|         } | ||||
| 
 | ||||
|         MNN::BackendConfig config; | ||||
|         config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal; | ||||
|         config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal; | ||||
|         std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4)); | ||||
|         ExecutorScope scope(executor); | ||||
|          | ||||
|         auto net = _createModel(); | ||||
|         auto x = _Input({1, channel, height, width}, NCHW, halide_type_of<float>()); | ||||
|  | @ -65,11 +71,7 @@ public: | |||
|          | ||||
|          | ||||
|         // clone model
 | ||||
|         MNN::BackendConfig config; | ||||
|         config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal; | ||||
|         config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal; | ||||
|         std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4)); | ||||
|         ExecutorScope scope(executor); | ||||
| 
 | ||||
|         std::unique_ptr<Module> tempModule(Module::clone(net.get())); | ||||
|          | ||||
|         auto xClone = _Input({1, channel, height, width}, NCHW, halide_type_of<float>()); | ||||
|  |  | |||
|  | @ -14,12 +14,16 @@ | |||
| #include <MNN/expr/Module.hpp> | ||||
| #include "MNNTestSuite.h" | ||||
| #include "MNN_generated.h" | ||||
| #include "TestUtils.h" | ||||
| using namespace MNN; | ||||
| using namespace MNN::Express; | ||||
| 
 | ||||
| class GatherExprTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
| 
 | ||||
|         std::unique_ptr<MNN::OpT> gatherOp(new MNN::OpT); | ||||
|         gatherOp->type = MNN::OpType_GatherND; | ||||
|         auto parameter = _Input({2, 2}, NHWC, halide_type_of<int32_t>()); | ||||
|  | @ -224,7 +228,8 @@ public: | |||
| class GatherNdReComputeTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|          | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         const float inpudata[]                  = {-1.0, -2.0, 3.0, 4.0}; | ||||
|         const int indices_data[]                = {0, 0, 1, 1}; | ||||
|         auto params                             = _Const(inpudata, {2, 2}, NHWC, halide_type_of<float>()); | ||||
|  |  | |||
|  | @ -22,6 +22,8 @@ public: | |||
|             MNN_ERROR("Currently don't test not cpu mmap\n"); | ||||
|             return true; | ||||
|         } | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         auto x = _Input({1, 3, 224, 224}, NC4HW4, halide_type_of<float>()); | ||||
|         x->setName("x"); | ||||
|         auto y = _Conv(1.0f, 0.01f, x, {3, 16}, {5, 5}); | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ | |||
| #include <random> | ||||
| #include "MNNTestSuite.h" | ||||
| #include "MNN_generated.h" | ||||
| #include "TestUtils.h" | ||||
| #include <MNN/expr/Expr.hpp> | ||||
| #include <MNN/expr/ExprCreator.hpp> | ||||
| #include <MNN/expr/Module.hpp> | ||||
|  | @ -66,6 +67,8 @@ static void _originMatMul(float* C, const float* A, const float* B, int e, int l | |||
| class MatMulTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         int e = 5, h = 4, l = 6; | ||||
|         if (true) { | ||||
|             // Test MatMul
 | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ | |||
| #include <MNN/expr/ExprCreator.hpp> | ||||
| #include <MNN/expr/Module.hpp> | ||||
| #include "MNNTestSuite.h" | ||||
| #include "TestUtils.h" | ||||
| using namespace MNN; | ||||
| using namespace MNN::Express; | ||||
| 
 | ||||
|  | @ -15,6 +16,8 @@ public: | |||
|         return summer; | ||||
|     } | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         std::vector<VARP> empty; | ||||
|         // Make Net
 | ||||
|         auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>()); | ||||
|  |  | |||
|  | @ -177,6 +177,8 @@ MNNTestSuiteRegister(ModuleTest, "expr/ModuleTest"); | |||
| class ModuleWrongInputTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         std::vector<int8_t> buffer; | ||||
|         // construct
 | ||||
|         { | ||||
|  | @ -244,6 +246,8 @@ MNNTestSuiteRegister(ModuleWrongInputTest, "expr/ModuleWrongInputTest"); | |||
| class RefTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         std::vector<int8_t> buffer; | ||||
|         // construct
 | ||||
|         { | ||||
|  | @ -318,6 +322,8 @@ public: | |||
|         } | ||||
|     } | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         std::vector<int8_t> buffer; | ||||
| #ifdef MNN_REDUCE_SIZE | ||||
|         return true; | ||||
|  | @ -1039,6 +1045,8 @@ MNNTestSuiteRegister(ConstMemoryReplaceTest, "expr/ConstMemoryReplaceTest"); | |||
| class MutlThreadConstReplaceTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         auto func = [precision](VARP y, int thread) { | ||||
|             flatbuffers::FlatBufferBuilder builderOutput(1024); | ||||
|             { | ||||
|  | @ -1499,6 +1507,8 @@ MNNTestSuiteRegister(ExecutorResetLoadModuleTest, "expr/ExecutorResetLoadModuleT | |||
| class SequenceForwardResizeTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         // Make Model include convolution in shape compute and content compute
 | ||||
|         auto x = _Input({1, 3, 24, 24}, NCHW, halide_type_of<float>()); | ||||
|         x->setName("x"); | ||||
|  | @ -1606,6 +1616,8 @@ MNNTestSuiteRegister(SequenceForwardResizeTest, "expr/SequenceForwardResizeTest" | |||
| class InputModuleTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         auto y = _mobileNetV1Expr(nullptr, false); | ||||
|         std::unique_ptr<MNN::NetT> net(new NetT); | ||||
|         Variable::save({y}, net.get()); | ||||
|  |  | |||
|  | @ -35,6 +35,8 @@ static std::shared_ptr<Module> _createModel() { | |||
| class RasterOutputTest : public MNNTestCase { | ||||
| public: | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         auto net = _createModel(); | ||||
|         auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>()); | ||||
|         auto y = _Transpose(x, {0, 1, 3, 2}); | ||||
|  |  | |||
|  | @ -66,6 +66,11 @@ int main(int argc, char* argv[]) { | |||
|         dynamicOption = atoi(argv[7]); | ||||
|         FUNC_PRINT(dynamicOption); | ||||
|     } | ||||
|     bool enableKleidiAI = false; | ||||
|     if (argc > 8) { | ||||
|         enableKleidiAI = atoi(argv[8]) > 0 ? true : false; | ||||
|         FUNC_PRINT(enableKleidiAI); | ||||
|     } | ||||
|     auto exe = MNN::Express::Executor::newExecutor(type, config, thread); | ||||
|     if (exe == nullptr) { | ||||
|         MNN_ERROR("Can't create executor with type:%d, exit!\n", type); | ||||
|  | @ -76,6 +81,7 @@ int main(int argc, char* argv[]) { | |||
|     // set hint
 | ||||
|     MNN::RuntimeHint hint; | ||||
|     hint.dynamicQuantOption = dynamicOption; | ||||
|     hint.enableKleidiAI = enableKleidiAI; | ||||
|     scope.Current()->getRuntime().second->setRuntimeHint(hint); | ||||
|     MNNTestSuite::get()->pStaus.memory = memory; | ||||
|     MNNTestSuite::get()->pStaus.precision = precision; | ||||
|  |  | |||
|  | @ -202,6 +202,8 @@ class BinaryBroadcastTest : public MNNTestCase { | |||
|     virtual ~BinaryBroadcastTest() = default; | ||||
| 
 | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         bool resultNCHW = testDimensionFormat(NCHW, precision); | ||||
|         bool resultNHWC = testDimensionFormat(NHWC, precision); | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,6 +17,8 @@ class ReverseTest : public MNNTestCase { | |||
| public: | ||||
|     virtual ~ReverseTest() = default; | ||||
|     virtual bool run(int precision) { | ||||
|         auto executor = cloneCurrentExecutor(); | ||||
|         ExecutorScope scope(executor); | ||||
|         std::shared_ptr<MNN::Express::Module> net; | ||||
|         { | ||||
|             auto input = _Input({3, 2, 3}, NCHW); | ||||
|  |  | |||
|  | @ -109,7 +109,7 @@ static inline std::vector<int> parseIntList(const std::string& str, char delim) | |||
| int main(int argc, char *argv[]) { | ||||
|     if (argc < 3) { | ||||
|         MNN_PRINT("=======================================================================================================================================\n"); | ||||
|         MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds]\n"); | ||||
|         MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds] [enableKleidiAI]\n"); | ||||
|         MNN_PRINT("=======================================================================================================================================\n"); | ||||
|         return 0; | ||||
|     } | ||||
|  | @ -247,11 +247,16 @@ int main(int argc, char *argv[]) { | |||
|     for (auto id : cpuIds) { | ||||
|         MNN_PRINT("%d ", id); | ||||
|     } | ||||
|     bool enableKleidiAI = false; | ||||
|     if (argc > 10) { | ||||
|         enableKleidiAI = atoi(argv[10]) > 0 ? true : false; | ||||
|     } | ||||
|     MNN_PRINT("\n"); | ||||
|     FUNC_PRINT(precision); | ||||
|     FUNC_PRINT(memory); | ||||
|     FUNC_PRINT(power); | ||||
|     FUNC_PRINT_ALL(cacheFileName, s); | ||||
|     FUNC_PRINT(enableKleidiAI); | ||||
|     // create session
 | ||||
|     MNN::ScheduleConfig config; | ||||
|     config.type      = type; | ||||
|  | @ -320,6 +325,10 @@ int main(int argc, char *argv[]) { | |||
|         rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 2); | ||||
|     } | ||||
| 
 | ||||
|     if (enableKleidiAI) { | ||||
|         rtmgr->setHint(Interpreter::CPU_ENABLE_KLEIDIAI, true); | ||||
|     } | ||||
| 
 | ||||
|     // rtmgr->setHint(Interpreter::CPU_SME2_INSTRUCTIONS, false);
 | ||||
| 
 | ||||
|     if (runMask & 2048) { | ||||
|  |  | |||
|  | @ -645,18 +645,21 @@ VARP Omni::gen_position_ids(int seq_len) { | |||
|         positionIds = _Input({3, seq_len}, NCHW, halide_type_of<int>()); | ||||
|     } | ||||
|     auto ptr = positionIds->writeMap<int>(); | ||||
|     if (mContext->gen_seq_len > 0) { | ||||
|         for (int i=0; i<seq_len; ++i) { | ||||
|             auto pos = mContext->gen_seq_len + mPositionIds.back() + i; | ||||
|             ptr[i + 0] = pos; | ||||
|             ptr[i + seq_len] = pos; | ||||
|             ptr[i + seq_len * 2] = pos; | ||||
|         } | ||||
|     if (seq_len == 1) { | ||||
|         ptr[0] = mContext->gen_seq_len + mPositionIds.back(); | ||||
|         ptr[1] = ptr[0]; | ||||
|         ptr[2] = ptr[0]; | ||||
|     } else { | ||||
|         for (int i = 0; i < seq_len; i++) { | ||||
|             ptr[i] = mPositionIds.mT[i]; | ||||
|             ptr[i + seq_len] = mPositionIds.mH[i]; | ||||
|             ptr[i + seq_len * 2] = mPositionIds.mW[i]; | ||||
|             if (mPositionIds.mT.size() != seq_len) { | ||||
|                 ptr[i] = i; | ||||
|                 ptr[i + seq_len] = i; | ||||
|                 ptr[i + seq_len * 2] = i; | ||||
|             } else { | ||||
|                 ptr[i] = mPositionIds.mT[i]; | ||||
|                 ptr[i + seq_len] = mPositionIds.mH[i]; | ||||
|                 ptr[i + seq_len * 2] = mPositionIds.mW[i]; | ||||
|             } | ||||
|         } | ||||
|         if (mTalker) { | ||||
|             mTalker->setPostionIds(mPositionIds); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue