diff --git a/pymnn/src/MNN.cc b/pymnn/src/MNN.cc index 71dc156b..849180e9 100644 --- a/pymnn/src/MNN.cc +++ b/pymnn/src/MNN.cc @@ -2054,8 +2054,6 @@ MOD_INIT(_mnncengine) }; write(obj, dtype, total_length); (*self)->unMap(); - Py_XDECREF(obj); - }); // Load And Save expr_module.def("load_as_list", @@ -2226,6 +2224,28 @@ MOD_INIT(_mnncengine) } PyObject *obj = value.ptr(); auto write = [](PyObject *obj, DType dtype, int64_t total_length) { + #ifndef USE_PRIVATE + if(PyArray_Check(obj)) { + //numpy support + if(total_length != PyArray_Size(obj)) { + throw std::runtime_error("data size does not match each other"); + } + int npy_type = PyArray_TYPE(obj); + int itemsize = getitemsize(dtype, npy_type); + PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj); + auto tmpBuffer = PyArray_DATA(obj_cont); + if(NULL == tmpBuffer) { + throw std::runtime_error("numpy failed to get buffer"); + } + auto data = malloc(total_length * itemsize); + if (nullptr == data) { + throw std::runtime_error("call to writeMap meet a error"); + } + memcpy(data, tmpBuffer, total_length * itemsize); + Py_XDECREF(obj_cont); + return data; + } +#endif INTS shapeData = getshape(obj); int64_t totalLengthData = 1; INTS stride; @@ -2280,7 +2300,6 @@ MOD_INIT(_mnncengine) ret = _Const((const void*)data, shape, data_format, dtype2htype(dtype)); free(data); } - Py_XDECREF(obj); return ret; },py::arg("value_list"), py::arg("shape"), py::arg("data_format")=NCHW, py::arg("dtype")=DType::DType_FLOAT); INTS default_stride = {1, 1};