mirror of https://github.com/alibaba/MNN.git
				
				
				
			add numpy support for const and fix refcount
This commit is contained in:
		
							parent
							
								
									9f4f6c091d
								
							
						
					
					
						commit
						f44cac33e2
					
				| 
						 | 
				
			
			@ -2054,8 +2054,6 @@ MOD_INIT(_mnncengine)
 | 
			
		|||
                };
 | 
			
		||||
                write(obj, dtype, total_length);
 | 
			
		||||
                (*self)->unMap();
 | 
			
		||||
                Py_XDECREF(obj);
 | 
			
		||||
 | 
			
		||||
            });
 | 
			
		||||
    // Load And Save
 | 
			
		||||
    expr_module.def("load_as_list",
 | 
			
		||||
| 
						 | 
				
			
			@ -2226,6 +2224,28 @@ MOD_INIT(_mnncengine)
 | 
			
		|||
                }
 | 
			
		||||
                PyObject *obj = value.ptr();
 | 
			
		||||
                auto write = [](PyObject *obj, DType dtype, int64_t total_length) {
 | 
			
		||||
 #ifndef USE_PRIVATE
 | 
			
		||||
                    if(PyArray_Check(obj)) {
 | 
			
		||||
                        //numpy support
 | 
			
		||||
                        if(total_length != PyArray_Size(obj)) {
 | 
			
		||||
                            throw std::runtime_error("data size does not match each other");
 | 
			
		||||
                        }
 | 
			
		||||
                        int npy_type = PyArray_TYPE(obj);
 | 
			
		||||
                        int itemsize = getitemsize(dtype, npy_type);
 | 
			
		||||
                        PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
 | 
			
		||||
                        auto tmpBuffer = PyArray_DATA(obj_cont);
 | 
			
		||||
                        if(NULL == tmpBuffer) {
 | 
			
		||||
                            throw std::runtime_error("numpy failed to get buffer");
 | 
			
		||||
                        }
 | 
			
		||||
                        auto data = malloc(total_length * itemsize);
 | 
			
		||||
                        if (nullptr == data) {
 | 
			
		||||
                            throw std::runtime_error("call to writeMap meet a error");
 | 
			
		||||
                        }
 | 
			
		||||
                        memcpy(data, tmpBuffer, total_length * itemsize);
 | 
			
		||||
                        Py_XDECREF(obj_cont);
 | 
			
		||||
                        return data;
 | 
			
		||||
                    }
 | 
			
		||||
#endif
 | 
			
		||||
                    INTS shapeData = getshape(obj);
 | 
			
		||||
                    int64_t totalLengthData = 1;
 | 
			
		||||
                    INTS stride;
 | 
			
		||||
| 
						 | 
				
			
			@ -2280,7 +2300,6 @@ MOD_INIT(_mnncengine)
 | 
			
		|||
                    ret = _Const((const void*)data, shape, data_format, dtype2htype(dtype));
 | 
			
		||||
                    free(data);
 | 
			
		||||
                }
 | 
			
		||||
                Py_XDECREF(obj);
 | 
			
		||||
                return ret;
 | 
			
		||||
            },py::arg("value_list"), py::arg("shape"), py::arg("data_format")=NCHW, py::arg("dtype")=DType::DType_FLOAT);
 | 
			
		||||
    INTS default_stride = {1, 1};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue