This commit is contained in:
如幻 2020-02-03 13:36:32 +08:00 committed by xiaying
parent cc21dcec1d
commit 9cbdfcb6f0
3 changed files with 490 additions and 1021 deletions

View File

@ -90,7 +90,7 @@ def configure_extension_build():
engine_link_args = []
engine_sources = [os.path.join(root_dir, "pymnn", "src", "MNN.cc")]
engine_include_dirs = [os.path.join(root_dir, "include")]
engine_include_dirs = [os.path.join(root_dir, "include")]
engine_include_dirs += [os.path.join(root_dir, "express")]
engine_include_dirs += [os.path.join(root_dir, "source")]
engine_include_dirs += [os.path.join(root_dir, "source", "core")]
engine_include_dirs += [os.path.join(root_dir, "schema", "current")]

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,8 @@
#pragma once
#include <string>
using namespace MNN;
using namespace MNN::Express;
using namespace std;
// Returns true if obj is a bytes/str or unicode object
inline bool checkString(PyObject* obj) {
return PyBytes_Check(obj) || PyUnicode_Check(obj);
@ -35,4 +38,63 @@ inline PyObject* string2Object(const std::string& str) {
return PyUnicode_FromString(str.c_str());
#endif
}
inline double unpackDouble(PyObject* obj) {
if (PyFloat_Check(obj)) {
return PyFloat_AS_DOUBLE(obj);
}
throw std::runtime_error("Overflow when unpacking double");
}
inline int64_t unpackLong(PyObject* obj) {
int overflow;
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
if (value == -1 && PyErr_Occurred()) {
throw std::exception();
}
if (overflow != 0) {
throw std::runtime_error("Overflow when unpacking long");
}
return (int64_t)value;
}
inline void store_scalar(void* data, DataType dtype, PyObject* obj) {
switch (dtype) {
case DataType_DT_UINT8: *(uint8_t*)data = (uint8_t)unpackLong(obj); break;
case DataType_DT_INT32: *(int32_t*)data = (int32_t)unpackLong(obj); break;
case DataType_DT_INT64: *(int64_t*)data = unpackLong(obj); break;
case DataType_DT_FLOAT: *(float*)data = (float)unpackDouble(obj); break;
case DataType_DT_DOUBLE: *(double*)data = (double)unpackDouble(obj); break;
default: throw std::runtime_error("invalid type");
}
}
INTS getshape(PyObject* seq) {
INTS shape;
while (PySequence_Check(seq)) {
auto length = PySequence_Length(seq);
if (length < 0) throw std::exception();
shape.push_back(length);
if (shape.size() > 20) {
throw std::exception();
}
if (length == 0) break;
seq = PySequence_GetItem(seq,0);
}
return shape;
}
void recursive_store(char* data, INTS shape, INTS stride, int dim, PyObject* obj, DataType dtype, int elementSize) {
auto ndim = shape.size();
if(dim == ndim) {
store_scalar(data, dtype, obj);
return;
}
auto n = shape[dim];
auto seq = PySequence_Fast(obj, "not a sequence");
if (!seq) throw std::exception();
auto seq_size = PySequence_Fast_GET_SIZE(seq);
if (seq_size != n) {
throw std::exception();
}
PyObject** items = PySequence_Fast_ITEMS(seq);
for (int i = 0; i < n; i++) {
recursive_store(data, shape, stride, dim + 1, items[i], dtype, elementSize);
data += stride[dim] * elementSize;
}
}