// MNN/source/backend/cpu/CPUDeconvolution.cpp
//
// CPUDeconvolution.cpp
// MNN
//
// Created by MNN on 2018/07/20.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "CPUDeconvolution.hpp"
#include "core/BufferAllocator.hpp"
#include "CPUBackend.hpp"
#include "core/Concurrency.h"
#include "core/Macro.h"
#include "core/AutoStorage.h"
#include "math/Matrix.hpp"
#include "core/TensorUtils.hpp"
#include "core/ConvolutionCommon.hpp"
#include "compute/CommonOptFunction.h"
#include "compute/ConvOpt.h"
#include "compute/DeconvolutionWithStride.hpp"
//#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
namespace MNN {
CPUDeconvolutionBasic::CPUDeconvolutionBasic(const Tensor* input, const Op* convOp, Backend* b)
    : CPUConvolution(convOp->main_as_Convolution2D()->common(), b) {
    // Cache the fused post-op parameters (activation clamp bounds etc.) and
    // the input channel count; both are reused by onResize of the subclasses.
    mPostParameters = getPostParameters();
    mSrcCount       = input->channel();
}
ErrorCode CPUDeconvolutionBasic::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    // For a transposed convolution the padding is derived from the
    // input/output spatial sizes rather than taken directly from the op.
    const auto padding = ConvolutionCommon::convolutionTransposePad(inputs[0], outputs[0], mCommon);
    mPadX = padding.first;
    mPadY = padding.second;
    return NO_ERROR;
}
CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* convOp, Backend* b)
    : CPUDeconvolutionBasic(input, convOp, b) {
    auto conv2D = convOp->main_as_Convolution2D();
    auto core   = static_cast<CPUBackend*>(b)->functions();
    // The bias buffer is padded up to a multiple of the backend's pack size.
    const int outputCount  = mCommon->outputCount();
    const int alignedCount = UP_DIV(outputCount, core->pack) * core->pack;
    mBias.reset(Tensor::createDevice<float>(std::vector<int>{alignedCount}));
    if (!b->onAcquireBuffer(mBias.get(), Backend::STATIC)) {
        mValid = false;
        return;
    }
    // Zero everything first so the alignment padding stays zero, then copy
    // the real bias values, converting to the backend's lower-precision
    // representation when it is not fp32.
    ::memset(mBias->host<float>(), 0, mBias->length(0) * core->bytes);
    auto biasData = conv2D->bias();
    if (4 == core->bytes) {
        ::memcpy(mBias->host<float>(), biasData->data(), biasData->size() * sizeof(float));
    } else {
        core->MNNFp32ToLowp(biasData->data(), mBias->host<int16_t>(), biasData->size());
    }
}
CPUDeconvolutionCommon::~CPUDeconvolutionCommon() {
    // Return the statically-acquired bias buffer to the backend allocator;
    // STATIC buffers are not freed by the Tensor destructor.
    backend()->onReleaseBuffer(mBias.get(), Backend::STATIC);
}
2021-04-08 15:34:23 +08:00
static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outputCount, int srcCount, int fh, int fw,
uint8_t* cache, const CoreFunctions* core) {
auto outputC4 = UP_DIV(outputCount, core->pack);
2020-07-04 01:21:30 +08:00
// c, n, h, w-> c, n/4 * 4, h, w
for (int c=0; c<srcCount; ++c) {
2021-04-08 15:34:23 +08:00
auto dst = cache + c * outputC4 * fw * fh * core->pack * core->bytes;
auto src = tempWeight + c * outputCount * fw * fh * core->bytes;
core->MNNPackCUnit((float*)dst, (const float*)src, fw*fh, outputCount);
}
2020-07-04 01:21:30 +08:00
//printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
2021-04-08 15:34:23 +08:00
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
}
CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend)
    : MNN::CPUDeconvolutionCommon(input, convOp, backend) {
    auto layer = convOp->main_as_Convolution2D()->common();
    auto core  = static_cast<CPUBackend*>(backend)->functions();
    // Fetch the raw (possibly quantized) weights. quanCommon keeps the
    // dequantized storage alive until the repack below is finished.
    const float* tempWeight = nullptr;
    int tempWeightSize      = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, convOp->main_as_Convolution2D(), &tempWeight, &tempWeightSize);
    int fw       = layer->kernelX();
    int fh       = layer->kernelY();
    int srcCount = mSrcCount;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    // Weight destination layout for the packed matmul:
    // [UP_DIV(ocAligned*fh*fw, hP), UP_DIV(ic, lP) * lP, hP].
    auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
    mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
    std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount}));
    bool success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) &&
                   backend->onAcquireBuffer(cache.get(), Backend::STATIC);
    if (!success) {
        mValid = false;
        return;
    }
    auto dest       = mWeight->host<uint8_t>();
    int outputCount = layer->outputCount();
    AutoStorage<uint8_t> lowpWeight;
    if (core->bytes < 4) {
        // Low-precision backend: convert the fp32 weights before repacking.
        lowpWeight.reset(outputCount * srcCount * fh * fw * core->bytes);
        if (lowpWeight.get() == nullptr) {
            // BUGFIX: release the already-acquired STATIC cache buffer on
            // this error path; it was previously leaked (STATIC memory is
            // not freed by the Tensor destructor).
            backend->onReleaseBuffer(cache.get(), Backend::STATIC);
            mValid = false;
            return;
        }
        core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw);
        tempWeight = (float*)lowpWeight.get();
    }
    _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host<uint8_t>(), core);
    // The scratch buffer is only needed during the repack above.
    backend->onReleaseBuffer(cache.get(), Backend::STATIC);
    mOrigin.reset(new CPUDeconvolutionOrigin(input, convOp, backend));
}
CPUDeconvolution::~CPUDeconvolution() {
    // Release the statically-acquired packed weight buffer.
    backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC);
}
ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    // Deconvolution is computed as one big GEMM over the whole input plane
    // (encoded here on a Strassen computor), followed at execute time by a
    // col2im-style scatter-accumulate into the output plus bias/clamp.
    // inputs[0] is the feature map, inputs[1] the packed weight,
    // inputs[2] the pack-aligned bias.
    CPUDeconvolutionBasic::onResize(inputs, outputs);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    auto input = inputs[0];
    auto output = outputs[0];
    auto oc = output->channel();
    // The bias tensor must already be padded to the backend pack size.
    if (UP_DIV(oc, core->pack) * core->pack != inputs[2]->length(0)) {
        return INPUT_DATA_ERROR;
    }

    auto ocC4 = UP_DIV(output->channel(), core->pack);
    auto icC4 = UP_DIV(input->channel(), core->pack);
    auto kw = mCommon->kernelX();
    auto kh = mCommon->kernelY();
    auto dilateX = mCommon->dilateX();
    auto dilateY = mCommon->dilateY();
    auto strideX = mCommon->strideX();
    auto strideY = mCommon->strideY();
    auto padX = mPadX;
    auto padY = mPadY;
    // NOTE: naming is inverted relative to a forward convolution here:
    // width/height are the *input* spatial dims, src_width/src_height are
    // the *output* dims the kernel footprints scatter into.
    auto width = input->width();
    auto height = input->height();
    auto src_height = output->height();
    auto src_width = output->width();
    auto kernelCount = ocC4 * mCommon->kernelX() * mCommon->kernelY();
    mPreFunctions.clear();
    mPostFunctions.clear();
    auto plane = width * height;
    const int maxDepth = 5; // recursion depth cap for the Strassen matmul
    // GEMM destination: one [plane, pack] tile per (oc block, ky, kx) triple.
    AutoRelease<Tensor> tempColTotalBuffer(Tensor::createDevice<float>({kernelCount, plane, core->pack}));
    auto res = backend()->onAcquireBuffer(tempColTotalBuffer.get(), Backend::DYNAMIC);
    if (!res) {
        return OUT_OF_MEMORY;
    }
    auto colBufferPtr = tempColTotalBuffer->host<float>();
    auto biasPtr = inputs[2]->host<float>();
    auto inputPtr = input->host<float>();
    // NOTE(review): tempInputBuffer appears unused after construction —
    // possibly kept only for lifetime/legacy reasons; confirm before removal.
    AutoRelease<Tensor> tempInputBuffer(
        Tensor::create<float>({icC4, plane, core->pack}, inputPtr));
    AutoRelease<Tensor> tempInput(Tensor::createDevice<float>({icC4, plane, core->pack}));
    auto threadNumber = ((CPUBackend*)backend())->threadNumber();
    if (input->batch() != 1) {
        // Multi-batch: the GEMM consumes one batch at a time, so stage the
        // current batch slice into a contiguous temp buffer before each run.
        res = backend()->onAcquireBuffer(tempInput.get(), Backend::DYNAMIC);
        if (!res) {
            return OUT_OF_MEMORY;
        }
        auto newInputPtr = tempInput->host<uint8_t>();
        // Copy Batch: each thread copies a stripe of channel blocks.
        mPreFunctions.emplace_back(std::make_pair([newInputPtr, icC4, plane, threadNumber, core](const float* srcBatch, int tId) {
            for (int c = tId; c<icC4; c+=threadNumber) {
                auto srcDepth = ((uint8_t*)srcBatch) + c * plane * core->pack * core->bytes;
                auto dstDepth = newInputPtr + c * plane * core->pack * core->bytes;
                ::memcpy(dstDepth, srcDepth, plane * core->pack * core->bytes);
            }
        }, threadNumber));
    } else {
        // Single batch: read the input in place, no staging copy needed.
        tempInput->buffer().host = (uint8_t*)inputPtr;
    }
    mMatMul.reset(new StrassenMatrixComputor(backend(), true, maxDepth));
    mMatMul->onEncode({tempInput.get(), inputs[1]}, {tempColTotalBuffer.get()});
    // Post pass (run per batch at execute time): scatter each GEMM column
    // back onto the output image, accumulating overlapping kernel
    // footprints, then add bias and apply the clamp parameters.
    mPostFunctions.emplace_back(std::make_pair([colBufferPtr, ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY,
        strideX, threadNumber, src_width, src_height, plane, biasPtr, this, core](float* outputPtr, int tId) {
        auto unitBytes = core->pack * core->bytes;
        // Each thread owns a stripe of output-channel blocks.
        for (int z = (tId); z < ocC4; z += threadNumber) {
            auto dstZ = (uint8_t*)outputPtr + z * src_height * src_width * unitBytes;
            auto srcZ = (uint8_t*)colBufferPtr + kw * kh * plane * z * unitBytes;
            auto dstB = dstZ;
            // Output starts at zero because the scatter accumulates.
            ::memset(dstB, 0, src_width * src_height * unitBytes);
            auto srcB = srcZ;
            for (int oy = 0; oy < height; ++oy) {
                for (int ox = 0; ox < width; ++ox) {
                    // Top-left corner of this input position's kernel
                    // footprint in the output image (may be negative).
                    int srcStartX = ox * strideX - padX;
                    int srcStartY = oy * strideY - padY;
                    // Clip the kernel window against the output borders.
                    int sfy = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));
                    int efy = ALIMIN(kh, UP_DIV(src_height - srcStartY, dilateY));
                    int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));
                    int efx = ALIMIN(kw, UP_DIV(src_width - srcStartX, dilateX));
                    auto dstStart = dstB + srcStartX * unitBytes + srcStartY * src_width * unitBytes;
                    auto srcStart = srcB + unitBytes * (ox + oy * width);
                    if (sfy >= efy || sfx >= efx) {
                        // Footprint entirely outside the output image.
                        continue;
                    }
                    for (int fy = sfy; fy < efy; ++fy) {
                        auto dstY = dstStart + fy * unitBytes * dilateY * src_width;
                        // Column buffer layout per z: [kh, kw, plane, pack].
                        auto srcY = srcStart + fy * kw * plane * unitBytes;
                        core->MNNAddC4WithStride((const float*)(srcY + sfx * plane * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), plane * core->pack, dilateX * core->pack, efx - sfx);
                    }
                }
            }
            // Add bias and apply the activation clamp for this channel block.
            core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr + unitBytes * z), src_height * src_width, 0, 0, 1, mPostParameters.data());
        }
    }, threadNumber));
    if (tempInput->host<float>() != inputPtr) {
        backend()->onReleaseBuffer(tempInput.get(), Backend::DYNAMIC);
    }
    backend()->onReleaseBuffer(tempColTotalBuffer.get(), Backend::DYNAMIC);
    return NO_ERROR;
}
ErrorCode CPUDeconvolutionOrigin::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    // Run the pipeline prepared in onResize once per batch:
    //   1. pre functions  — stage the batch slice into the temp buffer (multi-batch only)
    //   2. the encoded matrix multiply (weight x input -> column buffer)
    //   3. post functions — col2im scatter + bias + clamp into the output
    auto core   = static_cast<CPUBackend*>(backend())->functions();
    auto input  = inputs[0];
    auto output = outputs[0];
    const int batch = input->batch();
    // Per-batch strides in bytes; both tensors use the pack-aligned layout.
    // (Removed an unused `oc` local; strides are hoisted out of the loop.)
    auto icC4 = UP_DIV(input->channel(), core->pack);
    auto ocC4 = UP_DIV(output->channel(), core->pack);
    auto inputBatchStride  = input->width() * input->height() * icC4 * core->pack * core->bytes;
    auto outputBatchStride = output->width() * output->height() * ocC4 * core->pack * core->bytes;
    for (int i = 0; i < batch; ++i) {
        auto inputPtr  = input->host<uint8_t>() + i * inputBatchStride;
        auto outputPtr = output->host<uint8_t>() + i * outputBatchStride;
        for (auto& unit : mPreFunctions) {
            MNN_CONCURRENCY_BEGIN(tId, unit.second) {
                unit.first((float*)inputPtr, (int)tId);
            }
            MNN_CONCURRENCY_END();
        }
        mMatMul->onExecute();
        for (auto& unit : mPostFunctions) {
            MNN_CONCURRENCY_BEGIN(tId, unit.second) {
                unit.first((float*)outputPtr, (int)tId);
            }
            MNN_CONCURRENCY_END();
        }
    }
    return NO_ERROR;
}
class CPUDeconvolutionCreator : public CPUBackend::Creator {
public:
    // Choose the deconvolution implementation for the given op.
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const {
        auto convOp = op->main_as_Convolution2D();
        auto common = convOp->common();
        // A dedicated strided implementation exists for the no-dilation case,
        // but only on the plain CPU backend.
        const bool onCpu      = backend->type() == MNN_FORWARD_CPU;
        const bool strided    = common->strideY() > 1 || common->strideX() > 1;
        const bool noDilation = common->dilateX() == 1 && common->dilateY() == 1;
        if (onCpu && strided && noDilation) {
            return new DeconvolutionWithStride(inputs[0], op, backend);
        }
        // General fallback: GEMM + col2im based deconvolution.
        return new CPUDeconvolution(inputs[0], op, backend);
    }
};
REGISTER_CPU_OP_CREATOR(CPUDeconvolutionCreator, OpType_Deconvolution);
} // namespace MNN