mirror of https://github.com/alibaba/MNN.git
Compare commits
4 Commits
c41fa761dc
...
906bf697a2
| Author | SHA1 | Date |
|---|---|---|
|
|
906bf697a2 | |
|
|
6ed0ea9bf3 | |
|
|
bd59e631c2 | |
|
|
d0d879384a |
|
|
@ -824,15 +824,37 @@ static PyObject* PyMNNInterpreter_setSessionMode(PyMNNInterpreter *self, PyObjec
|
||||||
}
|
}
|
||||||
static PyObject* PyMNNInterpreter_setSessionHint(PyMNNInterpreter *self, PyObject *args) {
|
static PyObject* PyMNNInterpreter_setSessionHint(PyMNNInterpreter *self, PyObject *args) {
|
||||||
int type_val = 0;
|
int type_val = 0;
|
||||||
int num_val = 0;
|
PyObject* num_val = nullptr;
|
||||||
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
|
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
|
||||||
PyErr_SetString(PyExc_Exception,
|
PyErr_SetString(PyExc_Exception,
|
||||||
"PyMNNInterpreter_setSessionHint: Not interger input and interger input");
|
"PyMNNInterpreter_setSessionHint: Not interger input and interger/list/tuple input");
|
||||||
return NULL;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type = (MNN::Interpreter::HintMode)type_val;
|
auto type = (MNN::Interpreter::HintMode)type_val;
|
||||||
self->interpreter->setSessionHint(type, num_val);
|
if (PyList_Check(num_val)) {
|
||||||
|
size_t size = PyList_Size(num_val);
|
||||||
|
int* list = new int[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
|
||||||
|
}
|
||||||
|
self->interpreter->setSessionHint(type, list, size);
|
||||||
|
delete[] list;
|
||||||
|
} else if (PyTuple_Check(num_val)) {
|
||||||
|
size_t size = PyTuple_Size(num_val);
|
||||||
|
int* list = new int[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
|
||||||
|
}
|
||||||
|
self->interpreter->setSessionHint(type, list, size);
|
||||||
|
delete[] list;
|
||||||
|
} else if (PyLong_Check(num_val)) {
|
||||||
|
self->interpreter->setSessionHint(type, static_cast<int>(PyLong_AsLong(num_val)));
|
||||||
|
} else {
|
||||||
|
PyErr_SetString(PyExc_Exception,
|
||||||
|
"PyMNNInterpreter_setSessionHint: num_val must be a list, tuple or int");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {
|
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {
|
||||||
|
|
|
||||||
|
|
@ -345,15 +345,37 @@ static PyObject* PyMNNRuntimeManager_set_mode(PyMNNRuntimeManager *self, PyObjec
|
||||||
}
|
}
|
||||||
static PyObject* PyMNNRuntimeManager_set_hint(PyMNNRuntimeManager *self, PyObject *args) {
|
static PyObject* PyMNNRuntimeManager_set_hint(PyMNNRuntimeManager *self, PyObject *args) {
|
||||||
int type_val = 0;
|
int type_val = 0;
|
||||||
int num_val = 0;
|
PyObject* num_val = nullptr;
|
||||||
if (!PyArg_ParseTuple(args, "ii", &type_val, &num_val)) {
|
if (!PyArg_ParseTuple(args, "iO", &type_val, &num_val)) {
|
||||||
PyErr_SetString(PyExc_Exception,
|
PyErr_SetString(PyExc_Exception,
|
||||||
"PyMNNRuntimeManager_set_hint: Not interger input and interger input");
|
"PyMNNRuntimeManager_set_hint: Not interger input and interger/list/tuple input");
|
||||||
return NULL;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type = (MNN::Interpreter::HintMode)type_val;
|
auto type = (MNN::Interpreter::HintMode)type_val;
|
||||||
(*(self->ptr))->setHint(type, num_val);
|
if (PyList_Check(num_val)) {
|
||||||
|
size_t size = PyList_Size(num_val);
|
||||||
|
int* list = new int[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
list[i] = static_cast<int>(PyLong_AsLong(PyList_GetItem(num_val, i)));
|
||||||
|
}
|
||||||
|
(*(self->ptr))->setHint(type, list, size);
|
||||||
|
delete[] list;
|
||||||
|
} else if (PyTuple_Check(num_val)) {
|
||||||
|
size_t size = PyTuple_Size(num_val);
|
||||||
|
int* list = new int[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
list[i] = static_cast<int>(PyLong_AsLong(PyTuple_GetItem(num_val, i)));
|
||||||
|
}
|
||||||
|
(*(self->ptr))->setHint(type, list, size);
|
||||||
|
delete[] list;
|
||||||
|
} else if (PyLong_Check(num_val)) {
|
||||||
|
(*(self->ptr))->setHint(type, static_cast<int>(PyLong_AsLong(num_val)));
|
||||||
|
} else {
|
||||||
|
PyErr_SetString(PyExc_Exception,
|
||||||
|
"PyMNNRuntimeManager_set_hint: num_val must be a list, tuple or int");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,20 @@ static inline void MNN__mm_storeu_si64(void* add, __m128i value) {
|
||||||
_mm_storeu_ps(temp, _mm_castsi128_ps(value));
|
_mm_storeu_ps(temp, _mm_castsi128_ps(value));
|
||||||
::memcpy(add, temp, sizeof(int64_t));
|
::memcpy(add, temp, sizeof(int64_t));
|
||||||
}
|
}
|
||||||
|
#if defined(_MSC_VER) && !defined(_mm256_extract_epi64)
|
||||||
|
static inline uint64_t _mm256_extract_epi64(__m256i a, const int index)
|
||||||
|
{
|
||||||
|
typedef union {
|
||||||
|
__m256i v;
|
||||||
|
uint64_t i64[4];
|
||||||
|
} extractor;
|
||||||
|
|
||||||
|
extractor u;
|
||||||
|
u.v = a;
|
||||||
|
|
||||||
|
return u.i64[index];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
#define POSTTREAT(N) \
|
#define POSTTREAT(N) \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue