From bd59e631c2bf304c04bb4f4759eb63fd975cd77e Mon Sep 17 00:00:00 2001 From: Jules <16029431+jules-ai@users.noreply.github.com> Date: Mon, 21 Jul 2025 09:44:30 +0000 Subject: [PATCH] update pymnn set_hint with seq --- pymnn/src/MNN.cc | 32 +++++++++++++++++++++++++++----- pymnn/src/nn.h | 32 +++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/pymnn/src/MNN.cc b/pymnn/src/MNN.cc index 93024466..7291dc67 100644 --- a/pymnn/src/MNN.cc +++ b/pymnn/src/MNN.cc @@ -824,15 +824,37 @@ static PyObject* PyMNNInterpreter_setSessionMode(PyMNNInterpreter *self, PyObjec } static PyObject* PyMNNInterpreter_setSessionHint(PyMNNInterpreter *self, PyObject *args) { int type_val = 0; - int num_val = 0; - if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) { + PyObject* num_val = nullptr; + if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) { PyErr_SetString(PyExc_Exception, - "PyMNNInterpreter_setSessionHint: Not interger input and interger input"); - return NULL; + "PyMNNInterpreter_setSessionHint: Not interger input and interger/list/tuple input"); + return nullptr; } auto type = (MNN::Interpreter::HintMode)type_val; - self->interpreter->setSessionHint(type, num_val); + if (PyList_Check(num_val)) { + size_t size = PyList_Size(num_val); + int* list = new int[size]; + for (int i = 0; i < size; i++) { + list[i] = static_cast(PyLong_AsLong(PyList_GetItem(num_val, i))); + } + self->interpreter->setSessionHint(type, list, size); + delete[] list; + } else if (PyTuple_Check(num_val)) { + size_t size = PyTuple_Size(num_val); + int* list = new int[size]; + for (int i = 0; i < size; i++) { + list[i] = static_cast(PyLong_AsLong(PyTuple_GetItem(num_val, i))); + } + self->interpreter->setSessionHint(type, list, size); + delete[] list; + } else if (PyLong_Check(num_val)) { + self->interpreter->setSessionHint(type, static_cast(PyLong_AsLong(num_val))); + } else { + PyErr_SetString(PyExc_Exception, + "PyMNNInterpreter_setSessionHint: num_val must be a list, tuple or int"); + return nullptr; + } Py_RETURN_NONE; } static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) { diff --git a/pymnn/src/nn.h b/pymnn/src/nn.h index a775cb0d..33c5dc71 100644 --- a/pymnn/src/nn.h +++ b/pymnn/src/nn.h @@ -345,15 +345,37 @@ static PyObject* PyMNNRuntimeManager_set_mode(PyMNNRuntimeManager *self, PyObjec } static PyObject* PyMNNRuntimeManager_set_hint(PyMNNRuntimeManager *self, PyObject *args) { int type_val = 0; - int num_val = 0; - if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) { + PyObject* num_val = nullptr; + if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) { PyErr_SetString(PyExc_Exception, - "PyMNNRuntimeManager_set_hint: Not interger input and interger input"); - return NULL; + "PyMNNRuntimeManager_set_hint: Not interger input and interger/list/tuple input"); + return nullptr; } auto type = (MNN::Interpreter::HintMode)type_val; - (*(self->ptr))->setHint(type, num_val); + if (PyList_Check(num_val)) { + size_t size = PyList_Size(num_val); + int* list = new int[size]; + for (int i = 0; i < size; i++) { + list[i] = static_cast(PyLong_AsLong(PyList_GetItem(num_val, i))); + } + (*(self->ptr))->setHint(type, list, size); + delete[] list; + } else if (PyTuple_Check(num_val)) { + size_t size = PyTuple_Size(num_val); + int* list = new int[size]; + for (int i = 0; i < size; i++) { + list[i] = static_cast(PyLong_AsLong(PyTuple_GetItem(num_val, i))); + } + (*(self->ptr))->setHint(type, list, size); + delete[] list; + } else if (PyLong_Check(num_val)) { + (*(self->ptr))->setHint(type, static_cast(PyLong_AsLong(num_val))); + } else { + PyErr_SetString(PyExc_Exception, + "PyMNNRuntimeManager_set_hint: num_val must be a list, tuple or int"); + return nullptr; + } Py_RETURN_NONE; }