MNN/source/backend/cuda/execution/ArgMaxExecution.cu

80 lines
2.4 KiB
Plaintext

#include "ArgMaxExecution.hpp"
namespace MNN {
namespace CUDA {
// Arg-max along one axis of a tensor addressed as [outside][dim][inside].
//
// Each loop iteration handles one (outer, inner) pair: it walks the `dim`
// entries that share that pair (consecutive entries are `inside` elements
// apart) and writes the position of the largest one to the output, which is
// addressed as [outside][inside]. The comparison is strict, so on ties the
// lowest index wins.
//
// Launch layout: 1-D grid of 1-D blocks; any grid size works because every
// thread advances by the total number of launched threads. No shared memory.
//
// count   - number of output elements to produce (outside * inside at the
//           visible call site).
// outside - not read by the kernel; kept so the call signature is unchanged.
// inside  - number of elements after the reduced axis (its stride).
// dim     - extent of the reduced axis; must be >= 1 because entry 0 is
//           read unconditionally.
// input   - device-accessible pointer to the source values.
// output  - device-accessible pointer receiving the winning index,
//           converted to T.
//
// NOTE(review): the index is stored as T (float at the visible call site);
// confirm that this matches the element type of the output tensor.
template <typename T>
__global__ void ARGMAX(const int count, const int outside, const int inside, const int dim,
                       const T *input, T *output) {
    // Widen before multiplying so the products are not computed in 32 bits.
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < (count); i += stride) {
        const int o = i / inside;  // outer index
        const int n = i % inside;  // inner index

        // First entry of the slice owned by (o, n). The inner offset `n` must be
        // part of the base address; without it every inner position would scan
        // the slice belonging to n == 0.
        const T *slice = input + static_cast<size_t>(o) * dim * inside + n;

        int index  = 0;
        T maxValue = slice[0];
        for (int j = 1; j < dim; j++) {
            const T value = slice[static_cast<size_t>(j) * inside];
            if (maxValue < value) {
                index    = j;
                maxValue = value;
            }
        }

        // o * inside + n == i, so the flat loop index is also the output offset.
        output[i] = static_cast<T>(index);
    }
    return;
}
// Stores the op pointer and caches the axis value taken from the op's
// ArgMax parameters. The axis is stored exactly as the op provides it.
ArgMaxExecution::ArgMaxExecution(const Op* op, Backend *backend) : Execution(backend) {
    mOp = op;
    const auto argMaxParam = mOp->main_as_ArgMax();
    mAxis = argMaxParam->axis();
}
// The destructor performs no cleanup of its own.
ArgMaxExecution::~ArgMaxExecution() = default;
// Derives the launch geometry from the shape of the first input:
//   mOutside - product of the extents before the axis,
//   mDim     - extent of the axis itself,
//   mInside  - product of the extents after the axis.
// A negative mAxis is rewritten in place as an offset from the end of the
// shape. `outputs` is not consulted. Always returns NO_ERROR.
ErrorCode ArgMaxExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto src = inputs[0];
    const auto rank = src->dimensions();

    // Negative axis counts back from the last dimension.
    if (mAxis < 0) {
        mAxis = rank + mAxis;
    }

    // Extents in front of the axis.
    mOutside = 1;
    for (int d = 0; d < mAxis; ++d) {
        mOutside *= src->length(d);
    }

    // The axis being reduced.
    mDim = src->length(mAxis);

    // Extents behind the axis.
    mInside = 1;
    for (int d = mAxis + 1; d < rank; ++d) {
        mInside *= src->length(d);
    }

    return NO_ERROR;
}
// Launches the ARGMAX kernel over mOutside * mInside output positions using
// the device addresses of the first input and first output tensor.
//
// The launch is skipped when there is nothing to compute (empty outer/inner
// extent) or when the reduced axis is empty. Always returns NO_ERROR; the
// launch result itself is not checked here.
//
// NOTE(review): both buffers are reinterpreted as float regardless of the
// tensors' element types, and the kernel therefore writes indices as float.
// Confirm this matches the dtypes the rest of the backend expects.
ErrorCode ArgMaxExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    const int count = mOutside * mInside;
    if (count <= 0 || mDim <= 0) {
        // count == 0: no output elements, and a launch could end up with a
        // zero-sized grid, which is an invalid configuration.
        // mDim == 0: the kernel reads entry 0 of the axis unconditionally,
        // which would be out of bounds for an empty axis.
        return NO_ERROR;
    }

    auto runtime = static_cast<CUDABackend *>(backend())->getCUDARuntime();
    auto input   = (void *)inputs[0]->deviceId();
    auto output  = (void *)outputs[0]->deviceId();

    // Grid and block sizes are chosen by the runtime helper.
    int block_num  = runtime->blocks_num(count);
    int thread_num = runtime->threads_num();
    ARGMAX<<<block_num, thread_num>>>(count, mOutside, mInside, mDim, (const float*)input,(float *)output);
    return NO_ERROR;
}
// Creator that builds ArgMaxExecution instances for the CUDA backend.
class ArgMaxCreator : public CUDABackend::Creator {
public:
    // Returns a heap-allocated ArgMaxExecution built from `op` and `backend`.
    // `inputs` and `outputs` are not consulted.
    // NOTE(review): the returned object is presumably owned by the caller - confirm.
    Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                        const MNN::Op* op, Backend* backend) const override {
        Execution* execution = new ArgMaxExecution(op, backend);
        return execution;
    }
};
// File-scope static object constructed with OpType_ArgMax.
// NOTE(review): presumably its constructor registers ArgMaxCreator with the
// CUDA backend for that op type - confirm against CUDACreatorRegister.
static CUDACreatorRegister<ArgMaxCreator> __init(OpType_ArgMax);
}
}