add numpy support for const and fix refcount

This commit is contained in:
如幻 2020-05-19 15:59:38 +08:00 committed by xiaying
parent 9f4f6c091d
commit f44cac33e2
1 changed files with 22 additions and 3 deletions

View File

@ -2054,8 +2054,6 @@ MOD_INIT(_mnncengine)
};
write(obj, dtype, total_length);
(*self)->unMap();
Py_XDECREF(obj);
});
// Load And Save
expr_module.def("load_as_list",
@ -2226,6 +2224,28 @@ MOD_INIT(_mnncengine)
}
PyObject *obj = value.ptr();
auto write = [](PyObject *obj, DType dtype, int64_t total_length) {
#ifndef USE_PRIVATE
if(PyArray_Check(obj)) {
//numpy support
if(total_length != PyArray_Size(obj)) {
throw std::runtime_error("data size does not match each other");
}
int npy_type = PyArray_TYPE(obj);
int itemsize = getitemsize(dtype, npy_type);
PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
auto tmpBuffer = PyArray_DATA(obj_cont);
if(NULL == tmpBuffer) {
throw std::runtime_error("numpy failed to get buffer");
}
auto data = malloc(total_length * itemsize);
if (nullptr == data) {
throw std::runtime_error("call to writeMap meet a error");
}
memcpy(data, tmpBuffer, total_length * itemsize);
Py_XDECREF(obj_cont);
return data;
}
#endif
INTS shapeData = getshape(obj);
int64_t totalLengthData = 1;
INTS stride;
@ -2280,7 +2300,6 @@ MOD_INIT(_mnncengine)
ret = _Const((const void*)data, shape, data_format, dtype2htype(dtype));
free(data);
}
Py_XDECREF(obj);
return ret;
},py::arg("value_list"), py::arg("shape"), py::arg("data_format")=NCHW, py::arg("dtype")=DType::DType_FLOAT);
INTS default_stride = {1, 1};