mirror of https://github.com/alibaba/MNN.git
update pymnn set_hint with seq
This commit is contained in:
parent
80a917a1b0
commit
bd59e631c2
|
|
@ -824,15 +824,37 @@ static PyObject* PyMNNInterpreter_setSessionMode(PyMNNInterpreter *self, PyObjec
|
|||
}
|
||||
static PyObject* PyMNNInterpreter_setSessionHint(PyMNNInterpreter *self, PyObject *args) {
|
||||
int type_val = 0;
|
||||
int num_val = 0;
|
||||
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
|
||||
PyObject* num_val = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNInterpreter_setSessionHint: Not interger input and interger input");
|
||||
return NULL;
|
||||
"PyMNNInterpreter_setSessionHint: Not interger input and interger/list/tuple input");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto type = (MNN::Interpreter::HintMode)type_val;
|
||||
self->interpreter->setSessionHint(type, num_val);
|
||||
if (PyList_Check(num_val)) {
|
||||
size_t size = PyList_Size(num_val);
|
||||
int* list = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
|
||||
}
|
||||
self->interpreter->setSessionHint(type, list, size);
|
||||
delete[] list;
|
||||
} else if (PyTuple_Check(num_val)) {
|
||||
size_t size = PyTuple_Size(num_val);
|
||||
int* list = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
|
||||
}
|
||||
self->interpreter->setSessionHint(type, list, size);
|
||||
delete[] list;
|
||||
} else if (PyLong_Check(num_val)) {
|
||||
self->interpreter->setSessionHint(type, static_cast<int>(PyLong_AsLong(num_val)));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNInterpreter_setSessionHint: num_val must be a list, tuple or int");
|
||||
return nullptr;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {
|
||||
|
|
|
|||
|
|
@ -345,15 +345,37 @@ static PyObject* PyMNNRuntimeManager_set_mode(PyMNNRuntimeManager *self, PyObjec
|
|||
}
|
||||
static PyObject* PyMNNRuntimeManager_set_hint(PyMNNRuntimeManager *self, PyObject *args) {
|
||||
int type_val = 0;
|
||||
int num_val = 0;
|
||||
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
|
||||
PyObject* num_val = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNRuntimeManager_set_hint: Not interger input and interger input");
|
||||
return NULL;
|
||||
"PyMNNRuntimeManager_set_hint: Not interger input and interger/list/tuple input");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto type = (MNN::Interpreter::HintMode)type_val;
|
||||
(*(self->ptr))->setHint(type, num_val);
|
||||
if (PyList_Check(num_val)) {
|
||||
size_t size = PyList_Size(num_val);
|
||||
int* list = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
|
||||
}
|
||||
(*(self->ptr))->setHint(type, list, size);
|
||||
delete[] list;
|
||||
} else if (PyTuple_Check(num_val)) {
|
||||
size_t size = PyTuple_Size(num_val);
|
||||
int* list = new int[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
|
||||
}
|
||||
(*(self->ptr))->setHint(type, list, size);
|
||||
delete[] list;
|
||||
} else if (PyLong_Check(num_val)) {
|
||||
(*(self->ptr))->setHint(type, static_cast<int>(PyLong_AsLong(num_val)));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNRuntimeManager_set_hint: num_val must be a list, tuple or int");
|
||||
return nullptr;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue