MNN/source/backend/opencl/execution/buffer/SelectBufExecution.cpp

104 lines
3.7 KiB
C++

//
// SelectBufExecution.cpp
// MNN
//
// Created by MNN on 2023/08/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_OPENCL_BUFFER_CLOSED
#include "backend/opencl/execution/buffer/SelectBufExecution.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include "backend/opencl/core/OpenCLBackend.hpp"
namespace MNN {
namespace OpenCL {
SelectBufExecution::SelectBufExecution(Backend* backend) : Execution(backend) {
// Do nothing
}
ErrorCode SelectBufExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
auto inSize1 = inputs[1]->elementSize();
auto inSize2 = inputs[2]->elementSize();
auto openCLBackend = static_cast<OpenCLBackend*>(backend());
auto runtime = openCLBackend->getOpenCLRuntime();
if(inSize1 == 1)
mBuildOptions.emplace("-DINSIZE1_EUQAL_1");
if(inSize2 == 1)
mBuildOptions.emplace("-DINSIZE2_EUQAL_1");
mKernel = runtime->buildKernel("select_buf", "select_buf", mBuildOptions);
mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel));
std::vector<int> outputShape = tensorShapeFormat(outputs[0]);
int batch = outputShape.at(0);
int outputHeight = outputShape.at(1);
int outputWidth = outputShape.at(2);
int channels = outputShape.at(3);
int channelBlocks = (channels + 3) / 4;
int outSize = batch * channelBlocks * outputWidth * outputHeight * 4;
mGlobalWorkSize = {
static_cast<uint32_t>(outSize),
1
};
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
ret |= mKernel.setArg(idx++, openCLBuffer(inputs[0]));
ret |= mKernel.setArg(idx++, openCLBuffer(inputs[1]));
ret |= mKernel.setArg(idx++, openCLBuffer(inputs[2]));
ret |= mKernel.setArg(idx++, openCLBuffer(outputs[0]));
MNN_CHECK_CL_SUCCESS(ret, "setArg SelectBufExecution");
std::string kernelName = "select_buf";
mLocalSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
return NO_ERROR;
}
ErrorCode SelectBufExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("start SelectBufExecution onExecute...");
#endif
auto mOpenCLBackend = static_cast<OpenCLBackend*>(backend());
#ifdef ENABLE_OPENCL_TIME_PROFILER
cl::Event event;
runKernel2D(mKernel, mGlobalWorkSize, mLocalSize,
mOpenCLBackend->getOpenCLRuntime(), &event);
int costTime = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
MNN_PRINT("kernel cost:%d us Select\n",costTime);
#else
runKernel2D(mKernel, mGlobalWorkSize, mLocalSize,
mOpenCLBackend->getOpenCLRuntime());
#endif
#ifdef LOG_VERBOSE
MNN_PRINT("end SelectBufExecution onExecute...");
#endif
return NO_ERROR;
}
class SelectBufCreator : public OpenCLBackend::Creator {
public:
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op, Backend* backend) const override {
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
return new SelectBufExecution(backend);
}
};
OpenCLCreatorRegister<SelectBufCreator> __SelectBuf__(OpType_Select, BUFFER);
} // namespace OpenCL
} // namespace MNN
#endif /* MNN_OPENCL_BUFFER_CLOSED */