MNN/source/backend/opencl/execution/image/CommonExecution.cpp

78 lines
2.7 KiB
C++

//
// CommonExecution.cpp
// MNN
//
// Created by MNN on 2019/02/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "backend/opencl/execution/image/CommonExecution.hpp"
namespace MNN {
namespace OpenCL {
CommonExecution::CommonExecution(Backend *backend, const MNN::Op *Op)
: Execution(backend), mOp(Op) {
mOpType = Op->type();
}
ErrorCode CommonExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs){
auto openCLBackend = static_cast<OpenCLBackend*>(backend());
auto runtime = openCLBackend->getOpenCLRuntime();
openCLBackend->startRecord(mRecording);
auto error = onEncode(inputs, outputs);
if(NO_ERROR != error){
return error;
}
for (auto &unit : mUnits) {
bool lws_null = true;
for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) {
unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i]));
if(unit.localWorkSize.get()[i] != 0) {
lws_null = false;
}
}
if(lws_null){
unit.localWorkSize = cl::NullRange;
}
}
openCLBackend->endRecord(mRecording);
return NO_ERROR;
}
ErrorCode CommonExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto openCLBackend = static_cast<OpenCLBackend*>(backend());
auto runtime = openCLBackend->getOpenCLRuntime();
#ifdef ENABLE_OPENCL_TIME_PROFILER
int idx = 0;
#else
if(openCLBackend->isUseRecordQueue()){
openCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
return NO_ERROR;
}
#endif
auto res = CL_SUCCESS;
for (auto &unit : mUnits) {
#ifdef ENABLE_OPENCL_TIME_PROFILER
cl::Event event;
res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
cl::NullRange,
unit.globalWorkSize,
unit.localWorkSize,
nullptr,
&event);
runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event});
#else
res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
cl::NullRange,
unit.globalWorkSize,
unit.localWorkSize);
#endif
MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type()));
}
return NO_ERROR;
}
} // namespace OpenCL
}; // namespace MNN