MNN/source/backend/cuda/execution/PReLUExecution.cu

#include "PReLUExecution.hpp"
#include "MNNCUDADefine.hpp"
namespace MNN {
namespace CUDA {
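
// CUDA_KERNEL_LOOP is a grid-stride loop: the launched grid strides over all
// n elements, so a fixed launch configuration covers any element count.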
#define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
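
// PReLU over channel-packed data: PACK_NUMBER channels are interleaved per
// pack, `r` is the lane within a pack and `index / dim` the pack index.
// Negative values are scaled by their channel's slope; positives pass through.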
template<typename T>
__global__ void PRELU(const int n, const int channels, const int dim, const T* in, T* out,
                      const float* slopeData, int div_factor) {
    CUDA_KERNEL_LOOP(t, n) {
        int index = t / PACK_NUMBER;
        int r = t % PACK_NUMBER;
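        // Pack index along the channel axis; `% channels` keeps it in range and
        // `/ div_factor` collapses it to 0 when the slope is channel-shared.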
        int c = (index / dim) % channels / div_factor;
        float iv = (float)in[t];
        float ov = iv > 0.0f ? iv : iv * slopeData[c * PACK_NUMBER + r];
        out[t] = (T)ov;
    }
}
PReLUExecution::PReLUExecution(const PRelu* prelu, Backend *backend) : Execution(backend) {
    int slopeCount = prelu->slope()->size();
    auto alphaData = prelu->slope()->data();
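    // Round the slope buffer up to a whole number of PACK_NUMBER floats so the
    // kernel's slopeData[c * PACK_NUMBER + r] reads stay in bounds; the
    // cudaMemset below zero-fills the padded lanes.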
    auto staticPool = static_cast<CUDABackend*>(backend)->getStaticBufferPool();
    auto slopeSize = UP_DIV(slopeCount, PACK_NUMBER) * PACK_NUMBER * sizeof(float);
    mPreluStorage = staticPool->alloc(slopeSize);
    mDeviceSlope = (uint8_t*)mPreluStorage.first + mPreluStorage.second;
    MNN_ASSERT(nullptr != mDeviceSlope);
    cudaMemset(mDeviceSlope, 0, slopeSize);
    cudaMemcpy(mDeviceSlope, alphaData, slopeCount * sizeof(float), cudaMemcpyHostToDevice);
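    // A slope tensor of size 1 means one slope shared across all channels.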
    mIsChannelShared = slopeCount == 1;
}
PReLUExecution::~PReLUExecution() {
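    // The slope storage was taken from the backend's static buffer pool in the
    // constructor, so hand it back here.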
    auto staticPool = static_cast<CUDABackend*>(backend())->getStaticBufferPool();
    staticPool->free(mPreluStorage);
}
ErrorCode PReLUExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(inputs.size() == 1);
    MNN_ASSERT(outputs.size() == 1);
    auto input = inputs[0];
    MNN_ASSERT(input->dimensions() >= 2);
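    // Fold batch (dimension 0) and all spatial dimensions into mArea; the
    // channel axis (dimension 1) is packed separately below.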
    mArea = input->length(0);
    for (int i = 2; i < input->dimensions(); ++i) {
        mArea *= input->length(i);
    }
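    // Channels are processed in packs of PACK_NUMBER, so the total element
    // count is packs * area * PACK_NUMBER.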
    mChannel = UP_DIV(input->length(1), PACK_NUMBER);
    mCount = mChannel * mArea * PACK_NUMBER;
    // printf("mChannel:%d, mArea:%d, mCount:%d\n", mChannel, mArea, mCount);
    return NO_ERROR;
}
ErrorCode PReLUExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    auto bytes = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]);
    int block_num = runtime->blocks_num(mCount);
    int threads_num = runtime->threads_num();
    auto input_addr = (void*)inputs[0]->deviceId();
    auto output_addr = (void*)outputs[0]->deviceId();
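    // For a shared slope, dividing the pack index by mChannel collapses every
    // channel to slope index 0; otherwise each channel keeps its own slope.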
    int div_factor = mIsChannelShared ? mChannel : 1;
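    // The backend reports 2 bytes per element when the tensors hold
    // half-precision data; dispatch the matching kernel instantiation.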
    if (2 == bytes) {
        PRELU<<<block_num, threads_num>>>(mCount, mChannel, mArea, (const half *)input_addr, (half *)output_addr,
                                          (const float *)mDeviceSlope, div_factor);
    } else {
        PRELU<<<block_num, threads_num>>>(mCount, mChannel, mArea, (const float *)input_addr, (float *)output_addr,
                                          (const float *)mDeviceSlope, div_factor);
    }
    return NO_ERROR;
}
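
// Creator the CUDA backend uses to instantiate PReLUExecution; registered
// below for OpType_PReLU.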
class PReLUCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto param = op->main_as_PRelu();
        return new PReLUExecution(param, backend);
    }
};
static CUDACreatorRegister<PReLUCreator> __init(OpType_PReLU);
} // namespace CUDA
} // namespace MNN