MNN/source/backend/nnapi/execution/NNAPIPool.cpp

68 lines
2.2 KiB
C++
Raw Normal View History

2022-09-30 10:02:52 +08:00
//
// NNAPIPool.cpp
// MNN
//
// Created by MNN on 2022/09/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "NNAPIPool.hpp"
namespace MNN {
NNAPIPool::NNAPIPool(MNN::Backend *b, const MNN::Op *op, const std::vector<Tensor *> &inputs, const std::vector<MNN::Tensor *> &outputs) : NNAPICommonExecution(b, op) {
}
ErrorCode NNAPIPool::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto pool = mOp->main_as_Pool();
auto strideX = pool->strideX();
auto strideY = pool->strideY();
auto kernelX = pool->kernelX();
auto kernelY = pool->kernelY();
auto padMod = pool->padType();
auto global = pool->isGlobal();
int top, left, bottom, right;
if (nullptr != pool->pads()) {
MNN_ASSERT(pool->pads()->size() >= 4);
top = pool->pads()->Get(0);
left = pool->pads()->Get(1);
bottom = pool->pads()->Get(2);
right = pool->pads()->Get(3);
} else {
top = pool->padY();
left = pool->padX();
bottom = pool->padY();
right = pool->padX();
}
if (global) {
kernelX = inputs[0]->width();
kernelY = inputs[0]->height();
}
// NNAPI Pool inputs: [input, pad_left, pad_right, pad_top, pad_bottom, stride_w, stride_h, kernel_w, kernel_h, fusecode, NCHW/NHWC]
auto inputIdxs = getTensorIdxs(inputs);
// pad
inputIdxs.push_back(buildScalar(left));
inputIdxs.push_back(buildScalar(right));
inputIdxs.push_back(buildScalar(top));
inputIdxs.push_back(buildScalar(bottom));
// stride
inputIdxs.push_back(buildScalar(strideX));
inputIdxs.push_back(buildScalar(strideY));
// kernel
inputIdxs.push_back(buildScalar(kernelX));
inputIdxs.push_back(buildScalar(kernelY));
// fusecode
inputIdxs.push_back(buildScalar(ANEURALNETWORKS_FUSED_NONE));
// NCHW/NHWC
inputIdxs.push_back(buildScalar(mNCHW));
auto op = ANEURALNETWORKS_MAX_POOL_2D;
if (pool->type() == PoolType_AVEPOOL) {
op = ANEURALNETWORKS_AVERAGE_POOL_2D;
}
return buildOperation(op, inputIdxs, getTensorIdxs(outputs));
}
REGISTER_NNAPI_OP_CREATOR(NNAPIPool, OpType_Pooling)
} // namespace MNN