mirror of https://github.com/alibaba/MNN.git
Compare commits
4 Commits
61fcaf4472
...
034d115bc9
| Author | SHA1 | Date |
|---|---|---|
|
|
034d115bc9 | |
|
|
eb35c52260 | |
|
|
fa3a03d6e3 | |
|
|
30b463fe52 |
|
|
@ -181,7 +181,7 @@ class ModelDownloadManager private constructor(private val context: Context) {
|
|||
if (getDownloadedFile(modelId) != null) {
|
||||
downloadInfo.downlodaState = DownloadInfo.DownloadSate.COMPLETED
|
||||
downloadInfo.progress = 1.0
|
||||
} else if (getDownloadSizeTotal(ApplicationProvider.get(), modelId) > 0) {
|
||||
} else if (getDownloadSizeSaved(ApplicationProvider.get(), modelId) > 0) {
|
||||
val totalSize = getDownloadSizeTotal(ApplicationProvider.get(), modelId)
|
||||
val savedSize = getDownloadSizeSaved(ApplicationProvider.get(), modelId)
|
||||
downloadInfo.totalSize = totalSize
|
||||
|
|
|
|||
|
|
@ -27,16 +27,6 @@ namespace OpenCL {
|
|||
void registerOpenCLOps();
|
||||
#endif
|
||||
|
||||
std::mutex CLRuntime::globalRuntimeLock;
|
||||
std::weak_ptr<OpenCLRuntime> CLRuntime::globalRuntime;
|
||||
void CLRuntime::setGlobalCLRuntime(std::shared_ptr<OpenCLRuntime> runtime){
|
||||
std::lock_guard<std::mutex> _l(globalRuntimeLock);
|
||||
globalRuntime = runtime;
|
||||
}
|
||||
std::shared_ptr<OpenCLRuntime> CLRuntime::getGlobalCLRuntime(){
|
||||
auto sharedPtr = globalRuntime.lock();
|
||||
return sharedPtr;
|
||||
}
|
||||
|
||||
CLRuntime::CLRuntime(const Backend::Info& info){
|
||||
mInfo = info;
|
||||
|
|
@ -44,7 +34,6 @@ CLRuntime::CLRuntime(const Backend::Info& info){
|
|||
int device_id = 0;
|
||||
int platform_size = 0;
|
||||
void *context_ptr = nullptr;
|
||||
auto globalRuntimePtr = getGlobalCLRuntime();
|
||||
if (nullptr != info.user) {
|
||||
if (info.user->sharedContext != nullptr) {
|
||||
platform_id = ((MNNDeviceContext*)info.user->sharedContext)->platformId;
|
||||
|
|
@ -59,12 +48,7 @@ CLRuntime::CLRuntime(const Backend::Info& info){
|
|||
mMemory = mInfo.user->memory;
|
||||
}
|
||||
|
||||
if(globalRuntimePtr && globalRuntimePtr.get()->canShareRuntime(platform_size, platform_id, device_id, context_ptr)){
|
||||
mOpenCLRuntime = globalRuntimePtr;
|
||||
}else{
|
||||
mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint()));
|
||||
setGlobalCLRuntime(mOpenCLRuntime);
|
||||
}
|
||||
|
||||
//Whether runtimeError
|
||||
mCLRuntimeError = mOpenCLRuntime->isCreateError();
|
||||
|
|
|
|||
|
|
@ -62,8 +62,6 @@ public:
|
|||
void convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, int precision, int backend_memtype, bool svmFlag = false, int memtype = MNN_FORWARD_CPU) const;
|
||||
void convertFromDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, int precision, int backend_memtype, bool svmFlag = false, int memtype = MNN_FORWARD_CPU) const;
|
||||
void copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTensor, int precision, int backend_memtype) const;
|
||||
static void setGlobalCLRuntime(std::shared_ptr<OpenCLRuntime> runtime);
|
||||
static std::shared_ptr<OpenCLRuntime> getGlobalCLRuntime();
|
||||
|
||||
private:
|
||||
Backend::Info mInfo;
|
||||
|
|
@ -76,8 +74,6 @@ private:
|
|||
|
||||
friend class OpenCLBackend;
|
||||
TuneInfo* mTunedInfo;
|
||||
static std::weak_ptr<OpenCLRuntime> globalRuntime;
|
||||
static std::mutex globalRuntimeLock;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,16 @@ namespace MNN {
|
|||
|
||||
extern const std::map<std::string, const char*> OpenCLProgramMap;
|
||||
static std::mutex gCLMutex;
|
||||
static std::weak_ptr<::cl::Context> globalContext;
|
||||
static std::mutex gCLContextMutex;
|
||||
static std::shared_ptr<::cl::Context> getGlobalContext(){
|
||||
return globalContext.lock();
|
||||
}
|
||||
|
||||
static void setGlobalContext(std::shared_ptr<cl::Context> Context){
|
||||
std::lock_guard<std::mutex> lck(gCLContextMutex);
|
||||
globalContext = Context;
|
||||
}
|
||||
|
||||
bool OpenCLRuntime::getDeviceSupportsExtension(const cl::Device &device, const char *extensionName) {
|
||||
std::string extensions = device.getInfo<CL_DEVICE_EXTENSIONS>();
|
||||
|
|
@ -98,7 +108,6 @@ OpenCLRuntime::OpenCLRuntime(int platformSize, int platformId, int deviceId, voi
|
|||
#ifdef ENABLE_OPENCL_TIME_PROFILER
|
||||
properties |= CL_QUEUE_PROFILING_ENABLE;
|
||||
#endif
|
||||
cl_int res;
|
||||
// if device is QUALCOMM's and version is 2.0 , set spacial optimized param
|
||||
|
||||
sscanf(deviceVersion.c_str(), "%*s%f%*s", &mCLVersion);
|
||||
|
|
@ -200,11 +209,15 @@ OpenCLRuntime::OpenCLRuntime(int platformSize, int platformId, int deviceId, voi
|
|||
// Do nothing
|
||||
});
|
||||
}else{
|
||||
if(context_properties.size() > 0){
|
||||
context_properties.push_back(0);
|
||||
mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res));
|
||||
}else{
|
||||
mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), nullptr, nullptr, nullptr, &res));
|
||||
mContext = getGlobalContext();
|
||||
if(mContext == nullptr){
|
||||
if(context_properties.size() > 0){
|
||||
context_properties.push_back(0);
|
||||
mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res));
|
||||
}else{
|
||||
mContext = std::shared_ptr<cl::Context>(new cl::Context(std::vector<cl::Device>({*mFirstGPUDevicePtr}), nullptr, nullptr, nullptr, &res));
|
||||
}
|
||||
setGlobalContext(mContext);
|
||||
}
|
||||
}
|
||||
MNN_CHECK_CL_SUCCESS(res, "context");
|
||||
|
|
|
|||
|
|
@ -175,9 +175,6 @@ public:
|
|||
return mFlops;
|
||||
}
|
||||
|
||||
bool canShareRuntime(int platformSize, int platformId, int deviceId, void *contextPtr){
|
||||
return (platformSize == mInitInfo.platformSize) && (platformId == mInitInfo.platformId) && (deviceId == mInitInfo.deviceId) && (contextPtr == mInitInfo.contextPtr);
|
||||
}
|
||||
|
||||
double getCostTime(const cl::Event *event);
|
||||
double getQueuedTime(const cl::Event *event);
|
||||
|
|
|
|||
Loading…
Reference in New Issue