//
// CPUArgMax.cpp
// MNN
//
// Created by MNN on 2018/07/17.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "backend/cpu/CPUArgMax.hpp"
#include <float.h>
#include <algorithm>
#include <tuple>
#include <vector>
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "core/TensorUtils.hpp"

namespace MNN {

CPUArgMax::CPUArgMax(Backend *backend, ArgMinOrMax mode, int topk, int outMaxVal, int softmaxThreshold, int axis)
    : Execution(backend), mTopk(topk), mOutMaxVal(outMaxVal), mSoftmaxThreshold(softmaxThreshold), mAxis(axis), mMode(mode) {
    // nothing to do
}
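
// Both onResize and onExecute treat the input as a flattened
// [mNum, mDim, mKeyExtent] view: mNum outer slices, mDim elements along the
// reduced axis, and mKeyExtent contiguous elements to the right of the axis.
// Illustrative example (not from the original source): an NHWC input of shape
// [2, 3, 4, 5] reduced along axis 1 gives mNum = 2, mDim = 3 and
// mKeyExtent = 4 * 5 = 20, so elements along the reduced axis sit
// mKeyExtent floats apart in memory.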
ErrorCode CPUArgMax::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // acquire buffer space
    auto input  = inputs[0];
    auto output = outputs[0];
    auto inputDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat;

    mFromNHWC = inputDimensionFormat != MNN_DATA_FORMAT_NC4HW4;

    if (!mFromNHWC) {
        // if the input format is NC4HW4, convert it to NCHW first
        TensorUtils::copyShape(input, &mInputBuffer);
        TensorUtils::copyShape(output, &mOutputBuffer);

        backend()->onAcquireBuffer(&mInputBuffer, Backend::DYNAMIC);
        backend()->onAcquireBuffer(&mOutputBuffer, Backend::DYNAMIC);

        // release the temp buffers here so the dynamic allocator can reuse
        // the memory once this op has executed
        backend()->onReleaseBuffer(&mInputBuffer, Backend::DYNAMIC);
        backend()->onReleaseBuffer(&mOutputBuffer, Backend::DYNAMIC);
    }

    // compute params
    mNum       = 1;
    mDim       = 1;
    mKeyExtent = 1;

    // normalize a negative axis into [0, dimensions)
    if (mAxis < 0) {
        mAxis = mAxis + input->dimensions();
    }

    if (mFromNHWC) {
        // split the shape around mAxis: dims before it multiply into mNum,
        // dims after it multiply into mKeyExtent
        const int dimensions = input->dimensions();
        for (int i = 0; i < mAxis; ++i) {
            mNum = mNum * input->length(i);
        }
        mDim = input->length(mAxis);
        for (int i = mAxis + 1; i < dimensions; ++i) {
            mKeyExtent = mKeyExtent * input->length(i);
        }
    } else {
        if (mAxis == 0) {
            // Legacy code (really legacy): pick the reduced axis from the
            // first non-singleton dimension
            int iw = input->width(), ow = output->width();
            int ih = input->height(), oh = output->height();
            int ic = input->channel(), oc = output->channel();
            if (iw > 1) {
                mNum       = ic * ih;
                mDim       = iw;
                mKeyExtent = ow;
            } else if (ih > 1) { // iw = ow = 1
                mNum       = ic;
                mDim       = ih;
                mKeyExtent = oh;
            } else { // iw = ow = 1, ih = oh = 1
                mNum       = 1;
                mDim       = ic;
                mKeyExtent = oc;
            }
            // In Caffe the axis parameter may be absent; the converter sets
            // it to 10000 to signal that case.
            // See: tools/converter/source/caffe/ArgMax.cpp
        } else if (mAxis != 10000) {
            const int dimensions = input->dimensions();
            for (int i = 0; i < mAxis; ++i) {
                mNum = mNum * input->length(i);
            }
            mDim = input->length(mAxis);
            for (int i = mAxis + 1; i < dimensions; ++i) {
                mKeyExtent = mKeyExtent * input->length(i);
            }
        } else {
            MNN_PRINT("Error in ArgMax: this axis configuration is not implemented.\n");
            MNN_ASSERT(false);
        }
    }

    return NO_ERROR;
}
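
// onExecute takes one of two paths: a direct strided scan for non-NC4HW4
// inputs (top-1 argmax/argmin only), or a legacy Caffe path that first
// converts NC4HW4 to NCHW, filters candidates by the optional softmax
// threshold, and writes the top-k indices (and values, if requested) after a
// partial sort.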
ErrorCode CPUArgMax::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input  = inputs[0];
    auto output = outputs[0];

    using sortElementT = std::tuple<int, float>;
#define element_index(ele) (std::get<0>(ele))
#define element_value(ele) (std::get<1>(ele))
    // order (index, value) pairs by value, descending, so a partial sort
    // brings the k largest values to the front
    auto comp = [](const sortElementT &a, const sortElementT &b) -> bool {
        float va = element_value(a);
        float vb = element_value(b);
        return va > vb;
    };
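
    // Illustrative example (not from the original source): partial-sorting
    // {(0, 0.1f), (1, 0.9f), (2, 0.5f)} with comp and realTopK = 2, as done
    // in the legacy top-k path below, leaves (1, 0.9f) then (2, 0.5f) in the
    // first two slots; element_index then yields 1 and 2.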
    if (mFromNHWC) {
        if (mMode == ARGMAX) {
            auto srcOrigin = input->host<float>();
            auto dstOrigin = output->host<int>();
            for (int i = 0; i < mNum; ++i) {
                auto iptr = srcOrigin + i * mDim * mKeyExtent;
                auto optr = dstOrigin + i * mKeyExtent;

                for (int k = 0; k < mKeyExtent; ++k) {
                    int index      = 0;
                    float maxValue = -FLT_MAX;
                    // elements along the reduced axis are mKeyExtent apart,
                    // hence the stride in iptr[k + j * mKeyExtent]
                    for (int j = 0; j < mDim; ++j) {
                        auto val = iptr[k + j * mKeyExtent];
                        if (val > maxValue) {
                            maxValue = val;
                            index    = j;
                        }
                    }
                    optr[k] = index;
                }
            }
        } else {
            auto srcOrigin = input->host<float>();
            auto dstOrigin = output->host<int>();
            for (int i = 0; i < mNum; ++i) {
                auto iptr = srcOrigin + i * mDim * mKeyExtent;
                auto optr = dstOrigin + i * mKeyExtent;

                for (int k = 0; k < mKeyExtent; ++k) {
                    int index      = 0;
                    float minValue = FLT_MAX;
                    for (int j = 0; j < mDim; ++j) {
                        auto val = iptr[k + j * mKeyExtent];
                        if (val < minValue) {
                            minValue = val;
                            index    = j;
                        }
                    }
                    optr[k] = index;
                }
            }
        }

    } else {
        MNN_ASSERT(mMode == ARGMAX); // Caffe does not have an ArgMin layer
        // Legacy code for Caffe: the copy converts NC4HW4 into the NCHW buffer
        backend()->onCopyBuffer(input, &mInputBuffer);

        // threshold: when softmaxThreshold is set, only keep values at or
        // above the uniform softmax probability 1 / mDim
        float softmaxThreshold = -FLT_MAX;
        if (mSoftmaxThreshold) {
            softmaxThreshold = 1.0f / mDim;
        }

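        // Illustrative (not from the original source): with mDim = 10 the
        // threshold becomes 0.1f, so only entries whose values are >= 0.1f
        // enter the candidate vectors below.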
        float *srcOrigin = mInputBuffer.host<float>(); // used as NCHW input
        if (mAxis == 0) {
            // really legacy: each slice writes mTopk indices, then mTopk
            // values when outMaxVal is set
            float *dstOrigin = mOutputBuffer.host<float>();
            for (int i = 0; i < mNum; ++i) {
                float *iptr = srcOrigin + i * mDim;
                float *optr = dstOrigin + i * mKeyExtent;

                // apply threshold
                std::vector<sortElementT> vec;
                vec.reserve(mDim);
                for (int j = 0; j < mDim; ++j) {
                    float val = iptr[j];
                    if (val >= softmaxThreshold) {
                        vec.emplace_back(std::make_tuple(j, val));
                    }
                }
                int sortDim = (int)vec.size();

                // sort: only the first realTopK entries need to be ordered
                int realTopK = std::min(mTopk, sortDim);
                std::partial_sort(vec.begin(), vec.begin() + realTopK, vec.end(), comp);

                // copy index; pad with zeros when fewer than mTopk values
                // survive the threshold
                for (int j = 0; j < mTopk; ++j) {
                    if (j < sortDim) {
                        optr[j] = element_index(vec[j]);
                    } else {
                        optr[j] = 0.f;
                    }
                }

                // copy max value
                if (mOutMaxVal) {
                    for (int j = 0; j < mTopk; ++j) {
                        if (j < sortDim) {
                            optr[mTopk + j] = element_value(vec[j]);
                        } else {
                            optr[mTopk + j] = 0.f;
                        }
                    }
                }
            }
            backend()->onCopyBuffer(&mOutputBuffer, output);
        } else {
            float *dstOrigin = output->host<float>();
            int outMaxValNum = mOutMaxVal + 1;
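            // Illustrative layout (derived from the indexing below, not
            // stated in the original source): with mTopk = 2 and
            // outMaxVal = 1, each key position k owns a block of
            // outMaxValNum * mTopk = 4 floats: [idx0, idx1, val0, val1].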
            for (int i = 0; i < mNum; ++i) {
                float *iptr = srcOrigin + i * mDim * mKeyExtent;
                float *optr = dstOrigin + i * mKeyExtent * mTopk * outMaxValNum;

                for (int k = 0; k < mKeyExtent; ++k) {
                    // apply threshold
                    std::vector<sortElementT> vec;
                    vec.reserve(mDim);
                    for (int j = 0; j < mDim; ++j) {
                        float val = iptr[k + j * mKeyExtent];
                        if (val >= softmaxThreshold) {
                            vec.emplace_back(std::make_tuple(j, val));
                        }
                    }
                    int sortDim = (int)vec.size();

                    // sort
                    int realTopK = std::min(mTopk, sortDim);
                    std::partial_sort(vec.begin(), vec.begin() + realTopK, vec.end(), comp);

                    // copy index
                    for (int j = 0; j < mTopk; ++j) {
                        if (j < sortDim) {
                            optr[k * outMaxValNum * mTopk + j] = element_index(vec[j]);
                        } else {
                            optr[k * outMaxValNum * mTopk + j] = 0.f;
                        }
                    }

                    // copy max value
                    if (mOutMaxVal) {
                        for (int j = 0; j < mTopk; ++j) {
                            if (j < sortDim) {
                                optr[k * outMaxValNum * mTopk + mTopk + j] = element_value(vec[j]);
                            } else {
                                optr[k * outMaxValNum * mTopk + mTopk + j] = 0.f;
                            }
                        }
                    }
                }
            }
        }
    }

    return NO_ERROR;
}

class CPUArgMaxCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const {
        auto argMax = op->main_as_ArgMax();
        if (op->type() == OpType_ArgMin) {
            return new CPUArgMax(backend, CPUArgMax::ArgMinOrMax::ARGMIN,
                                 argMax->topK(), argMax->outMaxVal(), argMax->softmaxThreshold(), argMax->axis());
        } else {
            return new CPUArgMax(backend, CPUArgMax::ArgMinOrMax::ARGMAX,
                                 argMax->topK(), argMax->outMaxVal(), argMax->softmaxThreshold(), argMax->axis());
        }
    }
};
REGISTER_CPU_OP_CREATOR(CPUArgMaxCreator, OpType_ArgMax);
REGISTER_CPU_OP_CREATOR(CPUArgMaxCreator, OpType_ArgMin);
} // namespace MNN
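
// Minimal usage sketch (an assumption, not part of this file: it relies on
// MNN's public Express API declared in express/Expr.hpp and
// express/NeuralNetWorkOp.hpp):
//
//   using namespace MNN::Express;
//   auto x = _Input({2, 3, 4, 5}, NHWC);
//   // ... fill x->writeMap<float>() with data ...
//   auto y = _ArgMax(x, 1); // dispatches to CPUArgMax with mode ARGMAX
//   const int *indices = y->readMap<int>();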