Compare commits

...

11 Commits

Author SHA1 Message Date
Jenni Vandervort 61fcaf4472
Merge 30b463fe52 into f0742ad340 2025-06-27 13:54:14 +08:00
jxt1234 f0742ad340
Merge pull request #3658 from jules-ai/sync_ForwardType_schema
android / android_build (push) Has been cancelled Details
ios / ios_build (push) Has been cancelled Details
linux / linux_buil_test (push) Has been cancelled Details
macos / macos_buil_test (push) Has been cancelled Details
windows / windows_build_test (push) Has been cancelled Details
Sync forward type schema
2025-06-26 21:10:10 +08:00
jxt1234 2bb4b5b9c9
Merge pull request #3662 from jules-ai/refresh_OneDNNConvInt8
Refresh OneDNNConvInt8 code
2025-06-26 21:08:20 +08:00
jxt1234 b28d84cd84
Merge pull request #3669 from sunshine0523/master
Fixed: Sherpa-MNN could not load TTS models
2025-06-26 21:07:19 +08:00
KindBrave 998a4200bc refactor(sherpa-mnn): update offline tts models to use model path instead of data
- Change OfflineTtsKokoroModelConfig, OfflineTtsMatchaModelConfig, and OfflineTtsVitsModelConfig to use model path instead of model data
- Update offline-tts.cc to reflect the changes in model configuration
- Modify Init functions in offline-tts-kokoro-model.cc, offline-tts-matcha-model.cc, and offline-tts-vits-model.cc to accept model path instead of model data
- Update MNN::Express::Module::load function calls to use model path instead of model data
2025-06-26 11:29:26 +08:00
Jules cf39606610 update cmake minimum support version of dnnl 2025-06-24 01:46:57 +00:00
Jules 38c15d853b align OneDNNConvInt8 with updated function signature 2025-06-24 01:46:34 +00:00
Jules d9891d9424 update new MNN_generated.h 2025-06-23 02:52:56 +00:00
Jules edf1b3555c update cmake minimum support version 2025-06-23 02:34:02 +00:00
Jules fdb1ce073e sync ForwardType with current code 2025-06-23 02:29:07 +00:00
shunlibest 30b463fe52 fix(ModelDownloadManager): Update download progress check method
Change the download progress check from `getDownloadSizeTotal` to `getDownloadSizeSaved` to more accurately reflect the saved download data size.
2025-06-11 16:58:53 +08:00
12 changed files with 53 additions and 34 deletions

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 2.8)
cmake_minimum_required(VERSION 3.6)
# generate compile_commands.json
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
include(CheckCXXSymbolExists)

View File

@@ -181,7 +181,7 @@ class ModelDownloadManager private constructor(private val context: Context) {
if (getDownloadedFile(modelId) != null) {
downloadInfo.downlodaState = DownloadInfo.DownloadSate.COMPLETED
downloadInfo.progress = 1.0
} else if (getDownloadSizeTotal(ApplicationProvider.get(), modelId) > 0) {
} else if (getDownloadSizeSaved(ApplicationProvider.get(), modelId) > 0) {
val totalSize = getDownloadSizeTotal(ApplicationProvider.get(), modelId)
val savedSize = getDownloadSizeSaved(ApplicationProvider.get(), modelId)
downloadInfo.totalSize = totalSize

View File

@@ -32,9 +32,9 @@ class OfflineTtsKokoroModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto model_buf = ReadFile(config.kokoro.model);
auto model_path = config.kokoro.model.c_str();
auto voices_buf = ReadFile(config.kokoro.voices);
Init(model_buf.data(), model_buf.size(), voices_buf.data(),
Init(model_path, voices_buf.data(),
voices_buf.size());
}
@@ -43,9 +43,9 @@ class OfflineTtsKokoroModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto model_buf = ReadFile(mgr, config.kokoro.model);
auto model_path = config.kokoro.model.c_str();
auto voices_buf = ReadFile(mgr, config.kokoro.voices);
Init(model_buf.data(), model_buf.size(), voices_buf.data(),
Init(model_path, voices_buf.data(),
voices_buf.size());
}
@@ -96,9 +96,9 @@ class OfflineTtsKokoroModel::Impl {
}
private:
void Init(void *model_data, size_t model_data_length, const char *voices_data,
void Init(const char *model_path, const char *voices_data,
size_t voices_data_length) {
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length,
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, model_path,
sess_opts_.pManager, &sess_opts_.pConfig));
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

View File

@@ -31,8 +31,8 @@ class OfflineTtsMatchaModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config.matcha.acoustic_model);
Init(buf.data(), buf.size());
auto model_path = config.matcha.acoustic_model.c_str();
Init(model_path);
}
template <typename Manager>
@@ -40,8 +40,8 @@ class OfflineTtsMatchaModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config.matcha.acoustic_model);
Init(buf.data(), buf.size());
auto model_path = config.matcha.acoustic_model.c_str();
Init(model_path);
}
const OfflineTtsMatchaModelMetaData &GetMetaData() const {
@@ -117,8 +117,8 @@ class OfflineTtsMatchaModel::Impl {
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length,
void Init(const char *model_path) {
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, model_path,
sess_opts_.pManager, &sess_opts_.pConfig));
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

View File

@@ -31,8 +31,8 @@ class OfflineTtsVitsModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config.vits.model);
Init(buf.data(), buf.size());
auto model_path = config.vits.model.c_str();
Init(model_path);
}
template <typename Manager>
@@ -40,8 +40,8 @@ class OfflineTtsVitsModel::Impl {
: config_(config),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config.vits.model);
Init(buf.data(), buf.size());
auto model_path = config.vits.model.c_str();
Init(model_path);
}
MNN::Express::VARP Run(MNN::Express::VARP x, int sid, float speed) {
@@ -114,8 +114,8 @@ class OfflineTtsVitsModel::Impl {
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length,
void Init(const char *model_path) {
sess_ = std::unique_ptr<MNN::Express::Module>(MNN::Express::Module::load({}, {}, model_path,
sess_opts_.pManager, &sess_opts_.pConfig));
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

View File

@@ -16,13 +16,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
jfieldID fid;
fid = env->GetFieldID(cls, "model",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
"Lcom/k2fsa/sherpa/mnn/OfflineTtsModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
// vits
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
"Lcom/k2fsa/sherpa/mnn/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
jclass vits_cls = env->GetObjectClass(vits);
@@ -67,7 +67,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
// matcha
fid = env->GetFieldID(model_config_cls, "matcha",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
"Lcom/k2fsa/sherpa/mnn/OfflineTtsMatchaModelConfig;");
jobject matcha = env->GetObjectField(model, fid);
jclass matcha_cls = env->GetObjectClass(matcha);
@@ -115,7 +115,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
// kokoro
fid = env->GetFieldID(model_config_cls, "kokoro",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;");
"Lcom/k2fsa/sherpa/mnn/OfflineTtsKokoroModelConfig;");
jobject kokoro = env->GetObjectField(model, fid);
jclass kokoro_cls = env->GetObjectClass(kokoro);

View File

@@ -5,7 +5,7 @@ set(ROOT ${CMAKE_CURRENT_LIST_DIR}/../3rd_party/)
set(ONEDNN_DIR ${ROOT}/oneDNN/)
set(MNN_BUILD_DIR ${CMAKE_CURRENT_LIST_DIR}/../build/)
set(CONFIGURE_CMD cd ${ONEDNN_DIR} && cmake -DCMAKE_INSTALL_PREFIX=${MNN_BUILD_DIR} -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF -DDNNL_CPU_RUNTIME=SEQ)
set(CONFIGURE_CMD cd ${ONEDNN_DIR} && cmake -DCMAKE_INSTALL_PREFIX=${MNN_BUILD_DIR} -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF -DDNNL_CPU_RUNTIME=SEQ -DCMAKE_POLICY_VERSION_MINIMUM=3.10)
set(BUILD_CMD cd ${ONEDNN_DIR} && make -j8)
set(INSTALL_CMD cd ${ONEDNN_DIR} && make install)

View File

@@ -2655,18 +2655,24 @@ bool VerifyOpParameterVector(flatbuffers::Verifier &verifier, const flatbuffers:
enum ForwardType {
ForwardType_CPU = 0,
ForwardType_METAL = 1,
ForwardType_OPENCL = 2,
ForwardType_OPENGLES = 3,
ForwardType_VULKAN = 4,
ForwardType_CUDA = 2,
ForwardType_OPENCL = 3,
ForwardType_AUTO = 4,
ForwardType_NNAPI = 5,
ForwardType_OPENGLES = 6,
ForwardType_VULKAN = 7,
ForwardType_MIN = ForwardType_CPU,
ForwardType_MAX = ForwardType_VULKAN
};
inline const ForwardType (&EnumValuesForwardType())[5] {
inline const ForwardType (&EnumValuesForwardType())[8] {
static const ForwardType values[] = {
ForwardType_CPU,
ForwardType_METAL,
ForwardType_CUDA,
ForwardType_OPENCL,
ForwardType_AUTO,
ForwardType_NNAPI,
ForwardType_OPENGLES,
ForwardType_VULKAN
};
@@ -2677,7 +2683,10 @@ inline const char * const *EnumNamesForwardType() {
static const char * const names[] = {
"CPU",
"METAL",
"CUDA",
"OPENCL",
"AUTO",
"NNAPI",
"OPENGLES",
"VULKAN",
nullptr
@@ -8577,6 +8586,9 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
inline const flatbuffers::TypeTable *ForwardTypeTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_CHAR, 0, 0 },
{ flatbuffers::ET_CHAR, 0, 0 },
{ flatbuffers::ET_CHAR, 0, 0 },
{ flatbuffers::ET_CHAR, 0, 0 },
{ flatbuffers::ET_CHAR, 0, 0 },
{ flatbuffers::ET_CHAR, 0, 0 },
@@ -8589,12 +8601,15 @@ inline const flatbuffers::TypeTable *ForwardTypeTypeTable() {
static const char * const names[] = {
"CPU",
"METAL",
"CUDA",
"OPENCL",
"AUTO",
"NNAPI",
"OPENGLES",
"VULKAN"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_ENUM, 5, type_codes, type_refs, nullptr, names
flatbuffers::ST_ENUM, 8, type_codes, type_refs, nullptr, names
};
return &tt;
}

View File

@@ -458,7 +458,10 @@ table TensorDescribe {
enum ForwardType : byte {
CPU = 0,
METAL,
CUDA,
OPENCL,
AUTO,
NNAPI,
OPENGLES,
VULKAN,
}

View File

@@ -278,7 +278,7 @@ public:
const MNN::Op* op, Backend* backend) const override {
auto convOp = op->main_as_Convolution2D();
#ifdef MNN_USE_ONEDNN
return OneDNNConvInt8::create(backend, convOp, inputs, outputs);
return OneDNNConvInt8::create(backend, op, inputs, outputs);
#endif
auto core = static_cast<CPUBackend*>(backend)->functions();

View File

@@ -14,9 +14,10 @@ OneDNNConvInt8::~OneDNNConvInt8() {
// Do nothing
}
Execution* OneDNNConvInt8::create(Backend* backend, const MNN::Convolution2D* convParam, const std::vector<Tensor*>& inputs, const std::vector<Tensor *> &outputs) {
Execution* OneDNNConvInt8::create(Backend* backend, const MNN::Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor *> &outputs) {
std::shared_ptr<OneDNNConvInt8::Resource> resource(new OneDNNConvInt8::Resource);
resource->backend = backend;
const auto convParam = op->main_as_Convolution2D();
const auto convCommon = convParam->common();
const auto kw = convCommon->kernelX();
const auto kh = convCommon->kernelY();
@@ -68,7 +69,7 @@ Execution* OneDNNConvInt8::create(Backend* backend, const MNN::Convolution2D* co
}
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
if (convParam->quanParameter() != nullptr) {
quanCommon = ConvolutionCommon::load(convParam, backend, false);
quanCommon = ConvolutionCommon::load(op, backend, false);
weightSrc = quanCommon->weight.get();
}
auto user_weights = memory(user_weights_md, eng, (int8_t*)weightSrc);

View File

@@ -20,7 +20,7 @@ public:
primitive_attr conv_attr;
engine eng;
};
static Execution* create(Backend *backend, const MNN::Convolution2D *convOp, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
static Execution* create(Backend *backend, const MNN::Op* op, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
OneDNNConvInt8(std::shared_ptr<OneDNNConvInt8::Resource> resource, const MNN::Convolution2DCommon* common, Backend* bn);
virtual ~OneDNNConvInt8();
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;