update pymnn set_hint with seq

This commit is contained in:
Jules 2025-07-21 09:44:30 +00:00
parent 80a917a1b0
commit bd59e631c2
2 changed files with 54 additions and 10 deletions

View File

@ -824,15 +824,37 @@ static PyObject* PyMNNInterpreter_setSessionMode(PyMNNInterpreter *self, PyObjec
}
static PyObject* PyMNNInterpreter_setSessionHint(PyMNNInterpreter *self, PyObject *args) {
int type_val = 0;
int num_val = 0;
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
PyObject* num_val = nullptr;
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
PyErr_SetString(PyExc_Exception,
"PyMNNInterpreter_setSessionHint: Not interger input and interger input");
return NULL;
"PyMNNInterpreter_setSessionHint: Not interger input and interger/list/tuple input");
return nullptr;
}
auto type = (MNN::Interpreter::HintMode)type_val;
self->interpreter->setSessionHint(type, num_val);
if (PyList_Check(num_val)) {
size_t size = PyList_Size(num_val);
int* list = new int[size];
for (int i = 0; i < size; i++) {
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
}
self->interpreter->setSessionHint(type, list, size);
delete[] list;
} else if (PyTuple_Check(num_val)) {
size_t size = PyTuple_Size(num_val);
int* list = new int[size];
for (int i = 0; i < size; i++) {
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
}
self->interpreter->setSessionHint(type, list, size);
delete[] list;
} else if (PyLong_Check(num_val)) {
self->interpreter->setSessionHint(type, static_cast<int>(PyLong_AsLong(num_val)));
} else {
PyErr_SetString(PyExc_Exception,
"PyMNNInterpreter_setSessionHint: num_val must be a list, tuple or int");
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {

View File

@ -345,15 +345,37 @@ static PyObject* PyMNNRuntimeManager_set_mode(PyMNNRuntimeManager *self, PyObjec
}
static PyObject* PyMNNRuntimeManager_set_hint(PyMNNRuntimeManager *self, PyObject *args) {
int type_val = 0;
int num_val = 0;
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
PyObject* num_val = nullptr;
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
PyErr_SetString(PyExc_Exception,
"PyMNNRuntimeManager_set_hint: Not interger input and interger input");
return NULL;
"PyMNNRuntimeManager_set_hint: Not interger input and interger/list/tuple input");
return nullptr;
}
auto type = (MNN::Interpreter::HintMode)type_val;
(*(self->ptr))->setHint(type, num_val);
if (PyList_Check(num_val)) {
size_t size = PyList_Size(num_val);
int* list = new int[size];
for (int i = 0; i < size; i++) {
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
}
(*(self->ptr))->setHint(type, list, size);
delete[] list;
} else if (PyTuple_Check(num_val)) {
size_t size = PyTuple_Size(num_val);
int* list = new int[size];
for (int i = 0; i < size; i++) {
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
}
(*(self->ptr))->setHint(type, list, size);
delete[] list;
} else if (PyLong_Check(num_val)) {
(*(self->ptr))->setHint(type, static_cast<int>(PyLong_AsLong(num_val)));
} else {
PyErr_SetString(PyExc_Exception,
"PyMNNRuntimeManager_set_hint: num_val must be a list, tuple or int");
return nullptr;
}
Py_RETURN_NONE;
}