mirror of https://github.com/alibaba/MNN.git
471 lines
20 KiB
C++
471 lines
20 KiB
C++
//
//  Arm82Convolution.cpp
//  MNN
//
//  Created by MNN on 2020/01/07.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
|
|
#include "backend/arm82/Arm82Convolution.hpp"
|
|
#include "backend/arm82/Arm82Backend.hpp"
|
|
#include "backend/arm82/Arm82Convolution3x3.hpp"
|
|
#include "backend/arm82/Arm82OptFunc.hpp"
|
|
#include "core/Concurrency.h"
|
|
#include "core/Macro.h"
|
|
#include "core/TensorUtils.hpp"
|
|
#include "core/ConvolutionCommon.hpp"
|
|
|
|
#ifdef MNN_USE_NEON
|
|
#include <arm_neon.h>
|
|
#endif
|
|
|
|
namespace MNN {
|
|
|
|
#ifndef MNN_USE_NEON
|
|
static void MNNGemmFP16C8_UNIT(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 *weight, const FLOAT16 *bias,
|
|
size_t src_loop, size_t dst_step, size_t dst_loop, size_t relu, size_t relu6,
|
|
size_t realDstCount) {
|
|
const auto dst_step_tmp = dst_step / sizeof(FLOAT16);
|
|
|
|
for (int dz = 0; dz < dst_loop; ++dz) {
|
|
const auto weight_dz = weight + dz * src_loop * (ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT);
|
|
const auto bias_dz = bias + dz * ARMV82_CHANNEL_UNIT;
|
|
auto dst_z = dst + dz * dst_step_tmp;
|
|
for (int w = 0; w < DST_XUNIT; ++w) {
|
|
const auto src_x = src + w * ARMV82_CHANNEL_UNIT;
|
|
auto dst_x = dst_z + w * ARMV82_CHANNEL_UNIT;
|
|
FLOAT16 dstTemp[ARMV82_CHANNEL_UNIT];
|
|
|
|
memcpy(dstTemp, bias_dz, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
|
|
|
// MAC
|
|
for (int sz = 0; sz < src_loop; ++sz) {
|
|
const auto weight_sz = weight_dz + (ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT) * sz;
|
|
const auto src_z = src_x + sz * DST_XUNIT * ARMV82_CHANNEL_UNIT;
|
|
|
|
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
|
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
|
|
dstTemp[j] += src_z[i] * weight_sz[i * ARMV82_CHANNEL_UNIT + j];
|
|
}
|
|
}
|
|
} // end MAC
|
|
|
|
if (relu) {
|
|
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
|
if (dstTemp[j] < 0) {
|
|
dstTemp[j] = 0;
|
|
}
|
|
}
|
|
}
|
|
if (relu6) {
|
|
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
|
if (dstTemp[j] < 0) {
|
|
dstTemp[j] = 0;
|
|
}
|
|
if (dstTemp[j] > 6) {
|
|
dstTemp[j] = 6.0;
|
|
}
|
|
}
|
|
}
|
|
|
|
memcpy(dst_x, dstTemp, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
|
}
|
|
}
|
|
}
|
|
#endif
|
|
|
|
static void Im2ColTransformer(FLOAT16 *dst, const FLOAT16 *src, ConvolutionCommon::Im2ColParameter *im2colParam,
                              size_t xIndexStart, size_t realDstCount) {
    // Clear the whole column tile first. Taps that fall into padding, and tile columns at or
    // beyond realDstCount, are never written below and must read as zero in the GEMM.
    const int colBufferSize = im2colParam->kernelCountUnit * DST_XUNIT * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16);
    memset(dst, 0, colBufferSize);

    // Source is channel-blocked by ARMV82_CHANNEL_UNIT ("nc8hw8"): each block is one
    // ih*iw plane of ARMV82_CHANNEL_UNIT-wide pixels.
    const auto inH         = im2colParam->ih;
    const auto inW         = im2colParam->iw;
    const auto outW        = im2colParam->ow;
    const auto kernelH     = im2colParam->kernelY;
    const auto kernelW     = im2colParam->kernelX;
    const auto dilX        = im2colParam->dilateX;
    const auto dilY        = im2colParam->dilateY;
    const auto icBlocks    = im2colParam->icDiv4; // number of ARMV82_CHANNEL_UNIT channel blocks
    const auto planeStride = inW * inH * ARMV82_CHANNEL_UNIT;
    const auto strX        = im2colParam->strideX;
    const auto strY        = im2colParam->strideY;
    const auto padLeft     = im2colParam->padX;
    const auto padTop      = im2colParam->padY;
    constexpr int tileStep = ARMV82_CHANNEL_UNIT * DST_XUNIT;

    for (size_t col = 0; col < realDstCount; ++col) {
        const int flatIndex = (int)xIndexStart + (int)col;
        const int outX      = flatIndex % outW;
        const int outY      = flatIndex / outW;
        const int inX0      = outX * strX - padLeft;
        const int inY0      = outY * strY - padTop;

        // Range [begin, end) of kernel taps whose dilated position lies inside the image.
        const int kyBegin = ALIMAX(0, (UP_DIV(-inY0, dilY)));
        const int kyEnd   = ALIMIN(kernelH, UP_DIV(inH - inY0, dilY));
        const int kxBegin = ALIMAX(0, (UP_DIV(-inX0, dilX)));
        const int kxEnd   = ALIMIN(kernelW, UP_DIV(inW - inX0, dilX));

        FLOAT16 *colBase = dst + ARMV82_CHANNEL_UNIT * col;
        for (int kyIdx = kyBegin; kyIdx < kyEnd; ++kyIdx) {
            for (int kxIdx = kxBegin; kxIdx < kxEnd; ++kxIdx) {
                const FLOAT16 *srcPixel =
                    src + (inX0 + kxIdx * dilX + (inY0 + kyIdx * dilY) * inW) * ARMV82_CHANNEL_UNIT;
                FLOAT16 *dstTap = colBase + (kyIdx * kernelW + kxIdx) * icBlocks * tileStep;
                for (int blk = 0; blk < icBlocks; ++blk) {
                    memcpy(dstTap + blk * tileStep, srcPixel, ARMV82_CHANNEL_UNIT * sizeof(FLOAT16));
                    srcPixel += planeStride;
                }
            }
        }
    }

    // Re-arrange the tile into the layout the NEON GEMM kernel expects; the last argument
    // selects the variant for full (> DST_XUNIT / 2) versus partial tiles.
#ifdef MNN_USE_NEON
    MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, realDstCount > (DST_XUNIT / 2) ? 0 : 1);
#endif
}
|
|
|
|
static void Im2ColTransformer1x1(FLOAT16 *dst, const FLOAT16 *src, ConvolutionCommon::Im2ColParameter *im2colParam,
                                 size_t xIndexStart, size_t realDstCount) {
    // Zero the tile so columns at or beyond realDstCount stay zero for the GEMM.
    const int colBufferSize = im2colParam->kernelCountUnit * DST_XUNIT * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16);
    memset(dst, 0, colBufferSize);

    // Source is channel-blocked ("nc8hw8"). For a 1x1 / stride-1 / no-pad convolution the
    // realDstCount pixels of a tile are contiguous inside each channel-block plane, so each
    // block becomes a single contiguous copy.
    const auto planeStride   = im2colParam->iw * im2colParam->ih * ARMV82_CHANNEL_UNIT;
    const auto icBlocks      = im2colParam->icDiv4; // number of ARMV82_CHANNEL_UNIT channel blocks
    constexpr int tileStep   = ARMV82_CHANNEL_UNIT * DST_XUNIT;
    const size_t bytesPerRow = sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT * realDstCount;

    const FLOAT16 *srcBlock = src + xIndexStart * ARMV82_CHANNEL_UNIT;
    FLOAT16 *dstBlock       = dst;
    for (int blk = 0; blk < icBlocks; ++blk) {
        memcpy(dstBlock, srcBlock, bytesPerRow);
        srcBlock += planeStride;
        dstBlock += tileStep;
    }

    // Same GEMM-layout shuffle as the generic im2col path.
#ifdef MNN_USE_NEON
    MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, realDstCount > (DST_XUNIT / 2) ? 0 : 1);
#endif
}
|
|
|
|
// Builds the fp16 weight/bias tensors for the ARMv8.2 convolution.
//
// Weights come either as fp32 (convParam->weight()), as fp16 stored in the quantization
// buffer (quan type 3), or as a quantized format dequantized through
// ConvolutionCommon::load (quan type 4). They are converted to fp16 and reordered into
// [ocUnit][kernel][icUnit][ic][oc] tiles (NEON: oc tiles of 2 * ARMV82_CHANNEL_UNIT where
// possible). On any failure mValid is set to false and construction stops early.
Arm82Convolution::Arm82Convolution(const MNN::Convolution2D *convParam, Backend *bn) : Execution(bn) {
    const auto convCommon   = convParam->common();
    mCommon                 = convCommon;
    const int kx            = convCommon->kernelX();
    const int ky            = convCommon->kernelY();
    const int kernelCount   = kx * ky;
    int inputChannel        = convCommon->inputCount();
    const int outputChannel = convCommon->outputCount();
    const auto quanParam    = convParam->quanParameter();
    if (inputChannel == 0) {
        // inputCount not recorded in the model: derive it from the stored weight size.
        if (quanParam) {
            // NOTE(review): 2 bytes per weight, i.e. this assumes fp16 storage (type 3).
            inputChannel = quanParam->buffer()->size() / (2 * kernelCount * outputChannel);
        } else {
            inputChannel = convParam->weight()->size() / (kernelCount * outputChannel);
        }
    }
    const int inputChannelUnit  = UP_DIV(inputChannel, ARMV82_CHANNEL_UNIT);
    const int outputChannelUnit = UP_DIV(outputChannel, ARMV82_CHANNEL_UNIT);

    const int totalKernelCountUnit = kernelCount * inputChannelUnit;
    // Number of source weights that belong to a single output channel.
    const int oneChannleKernelSize = kernelCount * inputChannel;

    mWeightFp16.reset(Tensor::createDevice<uint16_t>(
        {outputChannelUnit, totalKernelCountUnit, ARMV82_CHANNEL_UNIT, ARMV82_CHANNEL_UNIT}));
    auto allocRes = bn->onAcquireBuffer(mWeightFp16.get(), Backend::STATIC);
    if (!allocRes) {
        mValid = false;
        return;
    }

    // Zero-fill so padded input/output channels contribute nothing.
    auto weightFp16DstPtr = mWeightFp16->host<FLOAT16>();
    memset(weightFp16DstPtr, 0, mWeightFp16->size());

    const FLOAT16 *fp16WeightPtr = nullptr;
    std::vector<FLOAT16> weightFp16;
    if (quanParam) {
        MNN_ASSERT((quanParam->type() == 3) || (quanParam->type() == 4));
        if (quanParam->type() == 3) {
            // the data type of weight is fp16
            fp16WeightPtr = reinterpret_cast<const FLOAT16 *>(quanParam->buffer()->data());
        }
        if (quanParam->type() == 4) {
            std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(quanParam, true);
            if (nullptr == quanCommon) {
                mValid = false;
                return;
            }
            // The dequantized weights hold exactly outputChannel * kernelCount * inputChannel
            // floats; the raw buffer size is the *compressed* byte count and must not be used.
            const int weightCount = outputChannel * oneChannleKernelSize;
            weightFp16.resize(weightCount);
            MNNQuantizeFP16(weightFp16.data(), quanCommon->weightFloat.get(), weightCount);
            fp16WeightPtr = weightFp16.data();
        }
    } else {
        // the data type of weight is fp32, then quantize weight to be fp16 data type
        int size = convParam->weight()->size();
        weightFp16.resize(size);
        MNNQuantizeFP16(weightFp16.data(), convParam->weight()->data(), size);
        fp16WeightPtr = weightFp16.data();
    }
    if (nullptr == fp16WeightPtr) {
        // Unsupported quantization type (MNN_ASSERT is compiled out in release builds).
        mValid = false;
        return;
    }

    auto weightFp16SrcPtr = fp16WeightPtr;

#ifdef MNN_USE_NEON
    // Reorder source weights [oc][ic][k] into tiles of ocUnit output channels:
    // [ocTile][k][icUnit][icInside][ocInside]. Output channels past outputChannel stay zero.
    int curOcChannel   = 0;
    auto reorderWeight = [&](int ocUnit, int ocUnitNum, const FLOAT16 *weightSrc, FLOAT16 *weightDst) {
        for (int oc = 0; oc < ocUnitNum; ++oc) {
            auto weightDstOcUnit   = weightDst + oc * kernelCount * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocUnit;
            const auto weightSrcOc = weightSrc + oc * ocUnit * oneChannleKernelSize;
            for (int k = 0; k < kernelCount; ++k) {
                auto weightDstK       = weightDstOcUnit + k * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocUnit;
                const auto weightSrcK = weightSrcOc + k;
                for (int y = 0; y < inputChannel; ++y) {
                    const int yOutSide = y / ARMV82_CHANNEL_UNIT;
                    const int yInSide  = y % ARMV82_CHANNEL_UNIT;
                    auto weightDstIc = weightDstK + yOutSide * ARMV82_CHANNEL_UNIT * ocUnit + yInSide * ocUnit;
                    const auto weigthSrcIc = weightSrcK + y * kernelCount;

                    for (int x = 0; x < ocUnit; ++x) {
                        if (curOcChannel + x < outputChannel) {
                            weightDstIc[x] = weigthSrcIc[x * oneChannleKernelSize];
                        }
                    }
                }
            }
            curOcChannel += ocUnit;
        }
    };
    const int ocDivDoubleUnit = outputChannelUnit / 2;
    // reorder weight in double ARMV82_CHANNEL_UNIT
    reorderWeight((ARMV82_CHANNEL_UNIT * 2), ocDivDoubleUnit, weightFp16SrcPtr, weightFp16DstPtr);
    auto weightRemainDst = weightFp16DstPtr + kernelCount * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocDivDoubleUnit *
                                                  (ARMV82_CHANNEL_UNIT * 2);
    auto weightRemainSrc = weightFp16SrcPtr + kernelCount * inputChannel * ocDivDoubleUnit * (ARMV82_CHANNEL_UNIT * 2);
    if (outputChannelUnit % 2 == 1) {
        // reorder weight in ARMV82_CHANNEL_UNIT
        reorderWeight(ARMV82_CHANNEL_UNIT, 1, weightRemainSrc, weightRemainDst);
    }
#else
    // reorder weight: [oc][ic][k] -> [ocUnit][k][icUnit][icInside][ocInside]
    const int ocUnitStride = inputChannelUnit * ARMV82_CHANNEL_UNIT * kernelCount * ARMV82_CHANNEL_UNIT;
    for (int k = 0; k < kernelCount; ++k) {
        const auto weightSrcK = weightFp16SrcPtr + k;
        auto weightDstK       = weightFp16DstPtr + k * inputChannelUnit * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT;
        for (int y = 0; y < inputChannel; ++y) {
            const int yOutSide = y / ARMV82_CHANNEL_UNIT;
            const int yInSide  = y % ARMV82_CHANNEL_UNIT;

            auto dstY =
                weightDstK + yOutSide * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT + yInSide * ARMV82_CHANNEL_UNIT;
            const auto srcY = weightSrcK + y * kernelCount;
            for (int x = 0; x < outputChannel; ++x) {
                const int xOutSide = x / ARMV82_CHANNEL_UNIT;
                const int xInSide  = x % ARMV82_CHANNEL_UNIT;
                const int dstIndex = xOutSide * ocUnitStride + xInSide;
                const int srcIndex = x * oneChannleKernelSize;
                dstY[dstIndex]     = srcY[srcIndex];
            }
        }
    }
#endif

    mBiasFp16.reset(Tensor::createDevice<uint16_t>({outputChannelUnit * ARMV82_CHANNEL_UNIT}));
    allocRes = bn->onAcquireBuffer(mBiasFp16.get(), Backend::STATIC);
    if (!allocRes) {
        mValid = false;
        return;
    }

    // TODO, bias is fp32, save bias also in fp16?
    auto biasDstPtr = mBiasFp16->host<FLOAT16>();
    memset(biasDstPtr, 0, mBiasFp16->size());
    MNNQuantizeFP16(biasDstPtr, convParam->bias()->data(), outputChannel);

    // Shape-independent im2col parameters; input/output sizes and SAME padding are set in onResize.
    mIm2ColParamter.dilateX         = convCommon->dilateX();
    mIm2ColParamter.dilateY         = convCommon->dilateY();
    mIm2ColParamter.strideX         = convCommon->strideX();
    mIm2ColParamter.strideY         = convCommon->strideY();
    mIm2ColParamter.padX            = convCommon->padX();
    mIm2ColParamter.padY            = convCommon->padY();
    mIm2ColParamter.icDiv4          = inputChannelUnit;
    mIm2ColParamter.kernelX         = convCommon->kernelX();
    mIm2ColParamter.kernelY         = convCommon->kernelY();
    mIm2ColParamter.kernelCountUnit = totalKernelCountUnit;

    mRelu6 = convCommon->relu6();
    mRelu  = convCommon->relu();
}
|
|
|
|
// Returns the STATIC weight and bias storage obtained in the constructor. Either tensor may
// be absent when construction stopped early, so each one is checked before release.
Arm82Convolution::~Arm82Convolution() {
    auto *owner = backend();
    if (nullptr != mWeightFp16) {
        owner->onReleaseBuffer(mWeightFp16.get(), Backend::STATIC);
    }
    if (nullptr != mBiasFp16) {
        owner->onReleaseBuffer(mBiasFp16.get(), Backend::STATIC);
    }
}
|
|
|
|
// Computes shape-dependent im2col parameters (padding, input/output sizes, tile count) and
// plans the per-thread scratch buffers: one im2col tile and one GEMM remainder tile per thread.
// The buffers are acquired and immediately released so the DYNAMIC memory planner can reuse
// the space after this op.
ErrorCode Arm82Convolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input  = inputs[0];
    auto output = outputs[0];

    mIm2ColParamter.padX = mCommon->padX();
    mIm2ColParamter.padY = mCommon->padY();
    if (mCommon->padMode() == PadMode_SAME) {
        const int kernelWidthSize  = (mCommon->kernelX() - 1) * mCommon->dilateX() + 1;
        const int kernelHeightSize = (mCommon->kernelY() - 1) * mCommon->dilateY() + 1;

        // SAME total padding is max(needed, 0). When the stride exceeds the effective kernel
        // extent the raw value can fall to -2 or below, which would yield a negative pad and
        // shift every sampling window.
        const int padNeededWidth =
            std::max(0, (output->width() - 1) * mCommon->strideX() + kernelWidthSize - input->width());
        const int padNeededHeight =
            std::max(0, (output->height() - 1) * mCommon->strideY() + kernelHeightSize - input->height());
        mIm2ColParamter.padX = padNeededWidth / 2;
        mIm2ColParamter.padY = padNeededHeight / 2;
    }

    mIm2ColParamter.ih = input->height();
    mIm2ColParamter.iw = input->width();
    mIm2ColParamter.oh = output->height();
    mIm2ColParamter.ow = output->width();

    mTileCount        = UP_DIV(output->height() * output->width(), DST_XUNIT);
    const int threads = std::max(1, static_cast<Arm82Backend *>(backend())->numberThread());
    mThreadNums       = std::min(threads, mTileCount);

    // NOTE(review): DT_BFLOAT16 appears to be used only for its 2-byte element size; the
    // buffers are accessed as FLOAT16 in onExecute.
    mIm2ColBuffer.setType(DataType_DT_BFLOAT16);
    mIm2ColBuffer.buffer().dimensions = 3;
    mIm2ColBuffer.setLength(0, mThreadNums);
    mIm2ColBuffer.setLength(1, DST_XUNIT);
    mIm2ColBuffer.setLength(2, mWeightFp16->length(1) * ARMV82_CHANNEL_UNIT);
    TensorUtils::setLinearLayout(&mIm2ColBuffer);

    mRemainBuffer.setType(DataType_DT_BFLOAT16);
    mRemainBuffer.buffer().dimensions = 3;
    mRemainBuffer.setLength(0, mThreadNums);
    mRemainBuffer.setLength(1, DST_XUNIT);
    mRemainBuffer.setLength(2, UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT) * ARMV82_CHANNEL_UNIT);
    TensorUtils::setLinearLayout(&mRemainBuffer);

    bool success = backend()->onAcquireBuffer(&mIm2ColBuffer, Backend::DYNAMIC);
    success      = success && backend()->onAcquireBuffer(&mRemainBuffer, Backend::DYNAMIC);
    if (!success) {
        return OUT_OF_MEMORY;
    }

    backend()->onReleaseBuffer(&mIm2ColBuffer, Backend::DYNAMIC);
    backend()->onReleaseBuffer(&mRemainBuffer, Backend::DYNAMIC);

    return NO_ERROR;
}
|
|
|
|
// Runs the convolution batch by batch. Output pixels are split into tiles of DST_XUNIT,
// distributed round-robin over mThreadNums threads. Each tile is im2col'ed into the thread's
// scratch buffer and multiplied by the packed fp16 weights. Full tiles are written straight to
// the output; a partial last tile goes through the remainder buffer and only realDstCount
// pixels are copied out.
ErrorCode Arm82Convolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input               = inputs[0];
    auto output              = outputs[0];
    const int outputPlaneLen = output->height() * output->width();

    const int dstZStep        = outputPlaneLen * ARMV82_CHANNEL_UNIT;
    const int batch           = input->batch();
    const int ocDiv8          = UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT);
    const int kernelCountUnit = mIm2ColParamter.kernelCountUnit;

    const auto inputDataPtr  = input->host<FLOAT16>();
    const auto weightDataPtr = mWeightFp16->host<FLOAT16>();
    const auto biasDataPtr   = mBiasFp16->host<FLOAT16>();
    auto im2ColPtr           = mIm2ColBuffer.host<FLOAT16>();
    auto outputDataPtr       = output->host<FLOAT16>();
    auto remainDataPtr       = mRemainBuffer.host<FLOAT16>();

    // The 1x1 fast path copies pixel i of the output plane from pixel i of the input plane,
    // which is valid only when output and input planes have identical geometry.
    auto im2ColProcess = Im2ColTransformer;
    const bool useFastIm2Col =
        mIm2ColParamter.kernelX == 1 && mIm2ColParamter.kernelY == 1 && mIm2ColParamter.strideX == 1 &&
        mIm2ColParamter.strideY == 1 && mIm2ColParamter.padX == 0 && mIm2ColParamter.padY == 0 &&
        mIm2ColParamter.iw == mIm2ColParamter.ow && mIm2ColParamter.ih == mIm2ColParamter.oh;
    if (useFastIm2Col) {
        im2ColProcess = Im2ColTransformer1x1;
    }

    const int inBatchStride  = ROUND_UP(input->channel(), ARMV82_CHANNEL_UNIT) * input->height() * input->width();
    const int outBatchStride = ocDiv8 * dstZStep;
    for (int bIndex = 0; bIndex < batch; ++bIndex) {
        const auto srcBatchPtr = inputDataPtr + bIndex * inBatchStride;
        auto dstBatchPtr       = outputDataPtr + bIndex * outBatchStride;

        auto threadFunction = [&](int tId) {
            auto im2ColCurPtr  = im2ColPtr + tId * mIm2ColBuffer.stride(0);
            auto gemmOutputPtr = remainDataPtr + tId * mRemainBuffer.stride(0);

            for (int tIndex = tId; tIndex < mTileCount; tIndex += mThreadNums) {
                const int xIndexStart  = tIndex * DST_XUNIT;
                const int realDstCount = ALIMIN(outputPlaneLen - xIndexStart, DST_XUNIT);

                im2ColProcess(im2ColCurPtr, srcBatchPtr, &mIm2ColParamter, xIndexStart, realDstCount);

                auto outputCurTilePtr = dstBatchPtr + xIndexStart * ARMV82_CHANNEL_UNIT;

                if (realDstCount == DST_XUNIT) {
                    // compute one tile
                    MNNGemmFP16C8_UNIT(outputCurTilePtr, im2ColCurPtr, weightDataPtr, biasDataPtr, kernelCountUnit,
                                       dstZStep * sizeof(FLOAT16), ocDiv8, mRelu, mRelu6, realDstCount);
                } else {
                    // compute the remain into scratch, then copy only the valid pixels out
                    MNNGemmFP16C8_UNIT(gemmOutputPtr, im2ColCurPtr, weightDataPtr, biasDataPtr, kernelCountUnit,
                                       ARMV82_CHANNEL_UNIT * DST_XUNIT * sizeof(FLOAT16), ocDiv8, mRelu, mRelu6,
                                       realDstCount);
                    for (int z = 0; z < ocDiv8; ++z) {
                        auto outputz = outputCurTilePtr + z * dstZStep;
                        auto srcz    = gemmOutputPtr + z * ARMV82_CHANNEL_UNIT * DST_XUNIT;
                        memcpy(outputz, srcz, realDstCount * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16));
                    }
                }
            }
        };

        MNN_CONCURRENCY_BEGIN(tId, mThreadNums)
        threadFunction((int)tId);
        MNN_CONCURRENCY_END();
    }

    return NO_ERROR;
}
|
|
|
|
class Arm82ConvolutionCreator : public Arm82Backend::Arm82Creator {
|
|
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
|
const MNN::Op *op, Backend *backend) const override {
|
|
auto convParam = op->main_as_Convolution2D();
|
|
// avoid other quantize method entry this creator
|
|
if(convParam->quanParameter() && convParam->quanParameter()->type() != 3){
|
|
return nullptr;
|
|
}
|
|
|
|
#ifdef __aarch64__
|
|
const auto param = convParam->common();
|
|
if (param->kernelX() == 3 && param->kernelY() == 3 && param->strideX() == 1 && param->strideY() == 1 &&
|
|
param->dilateX() == 1 && param->dilateY() == 1) {
|
|
return new Arm82Convolution3x3(convParam, backend);
|
|
}
|
|
#endif
|
|
return new Arm82Convolution(convParam, backend);
|
|
}
|
|
};
|
|
|
|
REGISTER_ARM82_OP_CREATOR(OpType_Convolution, Arm82ConvolutionCreator);
|
|
|
|
} // namespace MNN
|
|
|
|
#endif |