//
//  Convolution1x1Strassen.cpp
//  MNN
//
//  Created by MNN on 2019/02/12.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "Convolution1x1Strassen.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
#include <string.h>
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/BufferAllocator.hpp"
|
|
|
|
#include "backend/cpu/CPUBackend.hpp"
|
2020-02-26 09:57:17 +08:00
|
|
|
#include "CommonOptFunction.h"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Concurrency.h"
|
2020-02-26 09:57:17 +08:00
|
|
|
#include "ConvOpt.h"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Macro.h"
|
2019-04-17 10:49:11 +08:00
|
|
|
namespace MNN {
|
2020-05-17 23:09:45 +08:00
|
|
|
void Convolution1x1Strassen::_init(const Convolution2DCommon *common, Backend *b, const float *originWeight, size_t originWeightSize, const float *bias, size_t biasSize) {
|
2019-04-17 10:49:11 +08:00
|
|
|
mPostFunction = CPUConvolution::getPostFunction();
|
|
|
|
auto outputCount = (int)biasSize;
|
|
|
|
auto mSrcCount = (int)originWeightSize / outputCount;
|
2020-05-17 23:09:45 +08:00
|
|
|
int ePack, lPack, hPack;
|
|
|
|
MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
|
|
|
|
mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, hPack), UP_DIV(mSrcCount, lPack), lPack * hPack}));
|
|
|
|
mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC);
|
2019-04-17 10:49:11 +08:00
|
|
|
if (!mValid) {
|
|
|
|
MNN_ERROR("Not Enough Memory\n");
|
|
|
|
return;
|
|
|
|
}
|
2020-05-17 23:09:45 +08:00
|
|
|
MNNPackForMatMul_B(mWeight->host<float>(), originWeight, outputCount, mSrcCount, true);
|
2019-04-17 10:49:11 +08:00
|
|
|
mBias.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, 4), 4}));
|
|
|
|
mValid = b->onAcquireBuffer(mBias.get(), Backend::STATIC);
|
|
|
|
if (!mValid) {
|
|
|
|
MNN_ERROR("Not Enough Memory\n");
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
::memset(mBias->host<float>(), 0, mBias->size());
|
|
|
|
::memcpy(mBias->host<float>(), bias, biasSize * sizeof(float));
|
2020-05-17 23:09:45 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Builds a 1x1-convolution executor that lowers the convolution to a packed
// matrix multiplication. All weight/bias upload work (and the mValid flag)
// is delegated to _init.
Convolution1x1Strassen::Convolution1x1Strassen(const Convolution2DCommon *common, Backend *b, const float *originWeight,
                                               size_t originWeightSize, const float *bias, size_t biasSize)
    : CPUConvolution(common, b) {
    _init(common, b, originWeight, originWeightSize, bias, biasSize);
}
|
|
|
|
|
|
|
|
// Releases the STATIC weight/bias buffers acquired in _init.
// Both tensors are guarded: when the weight acquisition fails, _init returns
// before mBias is ever created, so mBias can legitimately be null here — the
// original code released it unconditionally and would pass a null tensor to
// onReleaseBuffer on that failure path.
Convolution1x1Strassen::~Convolution1x1Strassen() {
    if (nullptr != mWeight) {
        backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC);
    }
    if (nullptr != mBias) {
        backend()->onReleaseBuffer(mBias.get(), Backend::STATIC);
    }
}
|
|
|
|
|
|
|
|
// No cache is held beyond the member tensors managed by _init/the destructor,
// so there is nothing to free here.
ErrorCode Convolution1x1Strassen::onReleaseCache() {
    return NO_ERROR;
}
|
|
|
|
|
|
|
|
// Plans the per-thread scratch memory and fills the parameter block consumed
// by MNNPackedMatMul in onExecute. Returns OUT_OF_MEMORY if the scratch
// tensors cannot be reserved.
ErrorCode Convolution1x1Strassen::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    CPUConvolution::onResize(inputs, outputs);
    auto input  = inputs[0];
    auto output = outputs[0];
    // Matmul view of the 1x1 convolution: l = input channels, h = output
    // channels. The e dimension (output pixels) is tiled in onExecute.
    auto ic = input->channel();
    auto oc = output->channel();
    auto l  = ic;
    auto h  = oc;
    int ePack, lPack, hPack;
    MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
    auto threadNumber = static_cast<CPUBackend*>(backend())->threadNumber();
    // Per-thread scratch: one packed A tile (ePack x l) and one packed C tile
    // (ePack x h); dimension 0 indexes the thread (see stride(0) in onExecute).
    mTempInputPack.reset(Tensor::createDevice<float>({threadNumber, UP_DIV(l, lPack), ePack * lPack}));
    mTempOutputPack.reset(Tensor::createDevice<float>({threadNumber, UP_DIV(h, hPack), ePack * hPack}));
    bool res = true;
    res = res && backend()->onAcquireBuffer(mTempInputPack.get(), Backend::DYNAMIC);
    res = res && backend()->onAcquireBuffer(mTempOutputPack.get(), Backend::DYNAMIC);

    if (!res) {
        return OUT_OF_MEMORY;
    }
    // Parameter block for MNNPackedMatMul.
    // NOTE(review): indices 3 and 4 are never assigned — resize() value-initializes
    // them to 0, which presumably means "use default strides" in the kernel;
    // confirm against MNNPackedMatMul's documented parameter layout.
    mParameters.resize(6);
    mParameters[0] = 1;
    mParameters[1] = UP_DIV(l, lPack);
    mParameters[2] = UP_DIV(h, hPack);
    mParameters[5] = 0;
    // Release the scratch immediately: DYNAMIC buffers are only reserved in the
    // memory plan here; the same addresses are reused in onExecute. The release
    // results folded into res are deliberately not acted upon.
    res = res && backend()->onReleaseBuffer(mTempInputPack.get(), Backend::DYNAMIC);
    res = res && backend()->onReleaseBuffer(mTempOutputPack.get(), Backend::DYNAMIC);

    return NO_ERROR;
}
|
|
|
|
|
|
|
|
// Runs the 1x1 convolution as a packed matmul per batch, in two multithreaded
// passes: (1) pack/multiply/unpack tiles of ePack output pixels, (2) apply the
// bias/post-function on each group of 4 output channels.
ErrorCode Convolution1x1Strassen::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input       = inputs[0];
    auto output      = outputs[0];
    auto ic          = input->channel();
    auto oc          = output->channel();
    auto outputPlane = output->height() * output->width();
    // Matmul dimensions: e = output pixels, l = input channels, h = output channels.
    auto e           = outputPlane;
    auto l           = ic;
    auto h           = oc;
    auto ocC4        = UP_DIV(oc, 4);
    int ePack, lPack, hPack;
    MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
    auto tileCount    = UP_DIV(e, ePack);
    auto threadNumber = static_cast<CPUBackend*>(backend())->threadNumber();

    for (int batchIndex = 0; batchIndex < input->batch(); ++batchIndex) {
        // Pass 1: tiles are distributed round-robin over the threads; each
        // thread works in its own slice of the scratch tensors reserved in
        // onResize (indexed via stride(0)).
        MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
            auto gemmSrc = mTempInputPack->host<float>() + tId * mTempInputPack->stride(0);
            auto gemmDst = mTempOutputPack->host<float>() + tId * mTempOutputPack->stride(0);
            for (int index=tId; index < tileCount; index += threadNumber) {
                // NOTE(review): the `* 4` in the offsets assumes a C4-packed
                // layout (4 floats per pixel in the innermost dimension) —
                // consistent with the MNNPackC4/MNNUnPackC4 helpers; confirm.
                auto inputSrc = input->host<float>() + batchIndex * input->stride(0) + index * 4 * ePack;
                // The last tile may hold fewer than ePack pixels.
                auto eSize = std::min(e - index * ePack, ePack);
                MNNPackC4ForMatMul_A(gemmSrc, inputSrc, eSize, l, e);
                MNNPackedMatMul(gemmDst, gemmSrc, mWeight->host<float>(), mParameters.data());
                auto outputSrc = output->host<float>() + batchIndex * output->stride(0) + index * 4 * ePack;
                MNNUnPackC4ForMatMul_C(outputSrc, gemmDst, eSize, h, e);
            }
        }
        MNN_CONCURRENCY_END();

        // Pass 2: bias/post-processing, parallelized over the C4 output-channel
        // groups so no two threads touch the same output region.
        MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
            for (int dz=tId; dz<ocC4; dz+=threadNumber) {
                mPostFunction(output->host<float>() + batchIndex * output->stride(0) + dz * outputPlane * 4, mBias->host<float>() + dz * 4, outputPlane, 1);
            }
        }
        MNN_CONCURRENCY_END();
    }
    return NO_ERROR;
}
|
|
|
|
} // namespace MNN
|