| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ConvolutionWinograd.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2018/08/20.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "backend/cpu/compute/ConvolutionWinograd.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include <math.h>
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "backend/cpu/compute/CommonOptFunction.h"
 | 
					
						
							|  |  |  | #include "core/Concurrency.h"
 | 
					
						
							|  |  |  | #include "backend/cpu/compute/ConvOpt.h"
 | 
					
						
							|  |  |  | #include "core/Macro.h"
 | 
					
						
							|  |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							|  |  |  | #include "math/WingoradGenerater.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #ifdef MNN_USE_NEON
 | 
					
						
							|  |  |  | #include <arm_neon.h>
 | 
					
						
							|  |  |  | #endif
 | 
					
						
							| 
									
										
										
											
												- dynamic computation graph (beta)
	- add supports (/express)
	- add tests
	- add benchmarks with it (/benchmark/exprModels)
- Python
	- MNN engine and tools were submitted to pip
	- available on Windows/macOS/Linux
- Engine/Converter
	- add supports for each op benchmarking
	- refactor optimizer by separating steps
- CPU
	- add supports for Conv3D, Pool3D, ELU, ReverseSequence
	- fix ArgMax, Permute, Scale, BinaryOp, Slice, SliceTf
- OpenCL
	- add half transform in CPU
	- add broadcast supports for binary
	- optimize Conv2D, Reshape, Eltwise, Gemm, etc.
- OpenGL
	- add sub, real div supports for binary
	- add supports for unary
	- optimize Conv2D, Reshape
- Vulkan
	- add max supports for eltwise
- Metal
	- fix metallib missing problem
- Train/Quantization
	- use express to refactor training codes
											
										 
											2019-09-26 21:02:07 +08:00
										 |  |  | #define CONVOLUTION_WINOGRAD_MAX_UNIT 8
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #define CONVOLUTION_WINOGRAD_MIN_UNIT 2
 | 
					
						
							|  |  |  | using namespace MNN::Math; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //#define MNN_WINOGRAD_PRINT_REDUCE_RATE
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | ConvolutionWinograd::ConvolutionWinograd(const Convolution2DCommon *convOp, const Tensor *input, const Tensor *output, | 
					
						
							|  |  |  |                                          Backend *b, const float *originWeight, size_t originWeightSize, | 
					
						
							|  |  |  |                                          const float *bias, size_t biasSize, int unit) | 
					
						
							|  |  |  |     : MNN::CPUConvolution(convOp, b) { | 
					
						
							|  |  |  |     mBias.reset(Tensor::createDevice<float>({ALIGN_UP4((int)biasSize)})); | 
					
						
							|  |  |  |     mValid = backend()->onAcquireBuffer(mBias.get(), Backend::STATIC); | 
					
						
							|  |  |  |     if (!mValid) { | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ::memset(mBias->host<float>(), 0, mBias->size()); | 
					
						
							|  |  |  |     ::memcpy(mBias->host<float>(), bias, biasSize * sizeof(float)); | 
					
						
							|  |  |  |     mTempBuffer.buffer().type         = halide_type_of<float>(); | 
					
						
							|  |  |  |     mTransformMidBuffer.buffer().type = halide_type_of<float>(); | 
					
						
							|  |  |  |     MNN_ASSERT(mCommon->kernelX() == mCommon->kernelY()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int threadNumber = ((CPUBackend *)backend())->threadNumber(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto kernelSize = mCommon->kernelY(); | 
					
						
							|  |  |  |     WinogradGenerater generator(unit, kernelSize); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int alpha        = unit + kernelSize - 1; | 
					
						
							|  |  |  |     int alpha2       = alpha * alpha; | 
					
						
							|  |  |  |     mSourceTransform = WinogradFunction::chooseSourceTransform(alpha, alpha); | 
					
						
							|  |  |  |     mDestTransform   = WinogradFunction::chooseDestTransform(alpha, unit); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int srcCount                       = input->channel(); | 
					
						
							|  |  |  |     int outputCount                    = output->channel(); | 
					
						
							|  |  |  |     mTempBuffer.buffer().dim[0].extent = threadNumber; | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |     mTempBuffer.buffer().dim[1].extent = CONVOLUTION_TILED_NUMBER; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     mTempBuffer.buffer().dim[2].extent = UP_DIV(srcCount, 4) + UP_DIV(outputCount, 4); | 
					
						
							|  |  |  |     mTempBuffer.buffer().dim[3].extent = 4 * alpha2; | 
					
						
							|  |  |  |     TensorUtils::setLinearLayout(&mTempBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mTransformMidBuffer.buffer().dim[0].extent = threadNumber; | 
					
						
							|  |  |  |     mTransformMidBuffer.buffer().dim[1].extent = 2; | 
					
						
							|  |  |  |     mTransformMidBuffer.buffer().dim[2].extent = alpha2; | 
					
						
							|  |  |  |     mTransformMidBuffer.buffer().dim[3].extent = 4; | 
					
						
							|  |  |  |     TensorUtils::setLinearLayout(&mTransformMidBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mA = generator.A(); | 
					
						
							|  |  |  |     mB = generator.B(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Transform Kernel
 | 
					
						
							|  |  |  |     auto G = generator.G(); | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     std::shared_ptr<Tensor> sourceWeight(Tensor::create<float>( | 
					
						
							|  |  |  |         std::vector<int>{outputCount, srcCount, kernelSize, kernelSize}, (void *)originWeight, Tensor::CAFFE)); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     mWeight = generator.allocTransformWeight(sourceWeight.get(), 4, 4, false); | 
					
						
							|  |  |  |     mValid  = backend()->onAcquireBuffer(mWeight.get(), Backend::STATIC); | 
					
						
							|  |  |  |     if (!mValid) { | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     generator.transformWeight(mWeight.get(), sourceWeight.get()); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | ConvolutionWinograd::~ConvolutionWinograd() { | 
					
						
							|  |  |  |     if (nullptr != mBias) { | 
					
						
							|  |  |  |         backend()->onReleaseBuffer(mBias.get(), Backend::STATIC); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (nullptr != mWeight) { | 
					
						
							|  |  |  |         backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | 
					
						
							|  |  |  |     auto input   = inputs[0]; | 
					
						
							|  |  |  |     auto output  = outputs[0]; | 
					
						
							|  |  |  |     auto dstUnit = mA->length(1); | 
					
						
							|  |  |  |     auto srcUnit = mA->length(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto srcUnit2 = srcUnit * srcUnit; | 
					
						
							|  |  |  |     auto dstUnit2 = dstUnit * dstUnit; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ow   = output->width(); | 
					
						
							|  |  |  |     int oh   = output->height(); | 
					
						
							|  |  |  |     int iw   = input->width(); | 
					
						
							|  |  |  |     int ih   = input->height(); | 
					
						
							|  |  |  |     int ic_4 = UP_DIV(input->channel(), 4); | 
					
						
							|  |  |  |     int dc_4 = UP_DIV(output->channel(), 4); | 
					
						
							|  |  |  |     // MNN_PRINT("%d, %d\n", srcUnit, dstUnit);
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int padY = mPadY; | 
					
						
							|  |  |  |     int padX = mPadX; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto wUnit = UP_DIV(ow, dstUnit); | 
					
						
							|  |  |  |     auto hUnit = UP_DIV(oh, dstUnit); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto totalCount   = wUnit * hUnit; | 
					
						
							|  |  |  |     auto postFunction = mPostFunction; | 
					
						
							|  |  |  |     // MNN_PRINT("ow=%d, oh=%d\n", ow, oh);
 | 
					
						
							|  |  |  |     int threadNumber = std::max(((CPUBackend *)backend())->threadNumber(), 1); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |     int tileCount    = UP_DIV(totalCount, CONVOLUTION_TILED_NUMBER); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     threadNumber     = std::min(threadNumber, tileCount); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int batchIndex = 0; batchIndex < input->batch(); ++batchIndex) { | 
					
						
							|  |  |  |         auto srcOrigin = input->host<float>() + batchIndex * input->stride(0); | 
					
						
							|  |  |  |         auto dstOrigin = output->host<float>() + batchIndex * output->stride(0); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-05-05 20:27:57 +08:00
										 |  |  |         auto weight    = mWeight->host<float>(); | 
					
						
							|  |  |  |         auto bias      = mBias->host<float>(); | 
					
						
							|  |  |  |         auto tFunction = [&](int tId) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             auto _srcOrigin = mTempBuffer.host<float>() + tId * mTempBuffer.stride(0); | 
					
						
							|  |  |  |             auto midBuffer0 = mTransformMidBuffer.host<float>() + tId * mTransformMidBuffer.stride(0); | 
					
						
							|  |  |  |             auto midBuffer1 = | 
					
						
							|  |  |  |                 mTransformMidBuffer.host<float>() + tId * mTransformMidBuffer.stride(0) + mTransformMidBuffer.stride(1); | 
					
						
							|  |  |  |             for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) { | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                 int xIndex  = (int)tIndex * CONVOLUTION_TILED_NUMBER; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 int xReamin = totalCount - xIndex; | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                 int xC      = xReamin > CONVOLUTION_TILED_NUMBER ? CONVOLUTION_TILED_NUMBER : xReamin; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 /*Source Transform Begin*/ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     int sourceZStep = iw * ih * 4; | 
					
						
							|  |  |  |                     int dstZStep    = xC * 4; | 
					
						
							|  |  |  |                     int unitStep    = ic_4 * xC * 4; | 
					
						
							|  |  |  |                     for (int xi = 0; xi < xC; ++xi) { | 
					
						
							|  |  |  |                         auto index = xIndex + xi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int wIndex = index % wUnit; | 
					
						
							|  |  |  |                         int hIndex = index / wUnit; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int srcX  = wIndex * dstUnit - padX; | 
					
						
							|  |  |  |                         int srcY  = hIndex * dstUnit - padY; | 
					
						
							|  |  |  |                         int sy    = ALIMAX(0, srcY) - srcY; | 
					
						
							|  |  |  |                         int ey    = ALIMIN(srcY + srcUnit, ih) - srcY; | 
					
						
							|  |  |  |                         int sx    = ALIMAX(0, srcX) - srcX; | 
					
						
							|  |  |  |                         int ex    = ALIMIN(srcX + srcUnit, iw) - srcX; | 
					
						
							|  |  |  |                         int count = 4 * (ex - sx); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         auto dst_x = _srcOrigin + 4 * xi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         auto srcStart = srcOrigin + (srcX + srcY * iw) * 4; | 
					
						
							|  |  |  |                         if (ex - sx == srcUnit && ey - sy == srcUnit) { | 
					
						
							|  |  |  |                             for (int z = 0; z < ic_4; ++z) { | 
					
						
							|  |  |  |                                 auto srcZ = srcStart + z * sourceZStep; | 
					
						
							|  |  |  |                                 // Transform
 | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mSourceTransform(srcZ + 4 * i * iw, midBuffer1 + 4 * i, 4, 4 * srcUnit); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 auto dstZ = dst_x + z * dstZStep; | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mSourceTransform(midBuffer1 + 4 * i * srcUnit, dstZ + i * unitStep, 4, | 
					
						
							|  |  |  |                                                      unitStep * srcUnit); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         } else { | 
					
						
							|  |  |  |                             for (int z = 0; z < ic_4; ++z) { | 
					
						
							|  |  |  |                                 // Extract
 | 
					
						
							|  |  |  |                                 auto srcZ = srcStart + z * sourceZStep; | 
					
						
							|  |  |  |                                 ::memset(midBuffer0, 0, mTransformMidBuffer.stride(1) * sizeof(float)); | 
					
						
							|  |  |  |                                 if (count > 0) { | 
					
						
							|  |  |  |                                     for (int yy = sy; yy < ey; ++yy) { | 
					
						
							|  |  |  |                                         auto dst_yy = midBuffer0 + yy * srcUnit * 4 + sx * 4; | 
					
						
							|  |  |  |                                         auto src_yy = srcZ + 4 * iw * yy + sx * 4; | 
					
						
							|  |  |  |                                         ::memcpy(dst_yy, src_yy, count * sizeof(float)); | 
					
						
							|  |  |  |                                     } | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 // Transform
 | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mSourceTransform(midBuffer0 + 4 * i * srcUnit, midBuffer1 + 4 * i, 4, 4 * srcUnit); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 auto dstZ = dst_x + z * dstZStep; | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mSourceTransform(midBuffer1 + 4 * i * srcUnit, dstZ + i * unitStep, 4, | 
					
						
							|  |  |  |                                                      unitStep * srcUnit); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 /*Source Transform End*/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // Multi
 | 
					
						
							|  |  |  |                 auto _dstOrigin = _srcOrigin + xC * srcUnit2 * ic_4 * 4; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                 if (xC == CONVOLUTION_TILED_NUMBER) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                     for (int i = 0; i < srcUnit2; ++i) { | 
					
						
							|  |  |  |                         MNNGemmFloatUnit_4(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC, | 
					
						
							|  |  |  |                                            weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0); | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                     for (int i = 0; i < srcUnit2; ++i) { | 
					
						
							|  |  |  |                         MNNGemmFloatCommon_4(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC, | 
					
						
							|  |  |  |                                              weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, xC, 0); | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 /* Dest Transform And Post Treat Begin */ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     int dstZStep = ow * oh * 4; | 
					
						
							|  |  |  |                     int srcZStep = xC * 4; | 
					
						
							|  |  |  |                     int unitStep = dc_4 * xC * 4; | 
					
						
							|  |  |  |                     for (int xi = 0; xi < xC; ++xi) { | 
					
						
							|  |  |  |                         auto index = xIndex + xi; | 
					
						
							|  |  |  |                         auto srcXi = _dstOrigin + 4 * xi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int wIndex = index % wUnit; | 
					
						
							|  |  |  |                         int hIndex = index / wUnit; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int dstX = wIndex * dstUnit; | 
					
						
							|  |  |  |                         int dstY = hIndex * dstUnit; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         auto dstStart = dstOrigin + 4 * (dstX + dstY * ow); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int ey = ALIMIN(dstY + dstUnit, oh) - dstY; | 
					
						
							|  |  |  |                         int ex = ALIMIN(dstX + dstUnit, ow) - dstX; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         int count = ex * 4; | 
					
						
							|  |  |  |                         if (ex == dstUnit) { | 
					
						
							|  |  |  |                             for (int z = 0; z < dc_4; ++z) { | 
					
						
							|  |  |  |                                 auto dstZAddr = dstStart + z * dstZStep; | 
					
						
							|  |  |  |                                 auto srcZ     = srcXi + z * srcZStep; | 
					
						
							|  |  |  |                                 auto biasZ    = bias + 4 * z; | 
					
						
							|  |  |  |                                 // Transform
 | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mDestTransform(srcZ + i * unitStep, midBuffer0 + i * dstUnit * 4, | 
					
						
							|  |  |  |                                                    srcUnit * unitStep, 4); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 for (int i = 0; i < ey; ++i) { | 
					
						
							|  |  |  |                                     auto dstAddr = dstZAddr + i * 4 * ow; | 
					
						
							|  |  |  |                                     mDestTransform(midBuffer0 + i * 4, dstAddr, 4 * dstUnit, 4); | 
					
						
							|  |  |  |                                     postFunction(dstAddr, biasZ, dstUnit, 1); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         } else { | 
					
						
							|  |  |  |                             for (int z = 0; z < dc_4; ++z) { | 
					
						
							|  |  |  |                                 auto dstZAddr = dstStart + z * dstZStep; | 
					
						
							|  |  |  |                                 auto srcZ     = srcXi + z * srcZStep; | 
					
						
							|  |  |  |                                 // Transform
 | 
					
						
							|  |  |  |                                 for (int i = 0; i < srcUnit; ++i) { | 
					
						
							|  |  |  |                                     mDestTransform(srcZ + i * unitStep, midBuffer0 + i * dstUnit * 4, | 
					
						
							|  |  |  |                                                    srcUnit * unitStep, 4); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 for (int i = 0; i < ey; ++i) { | 
					
						
							|  |  |  |                                     mDestTransform(midBuffer0 + i * 4, midBuffer1 + i * dstUnit * 4, 4 * dstUnit, 4); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                                 // PostTreat
 | 
					
						
							|  |  |  |                                 postFunction(midBuffer1, bias + 4 * z, dstUnit2, 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                                 for (int yy = 0; yy < ey; ++yy) { | 
					
						
							|  |  |  |                                     auto dstYAddr = dstZAddr + yy * 4 * ow; | 
					
						
							|  |  |  |                                     auto srcYAddr = midBuffer1 + yy * 4 * dstUnit; | 
					
						
							|  |  |  |                                     ::memcpy(dstYAddr, srcYAddr, count * sizeof(float)); | 
					
						
							|  |  |  |                                 } | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 /*Dest Transform And Post Treat End*/ | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-05-05 20:27:57 +08:00
										 |  |  |         }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         MNN_CONCURRENCY_BEGIN(tId, threadNumber) { | 
					
						
							|  |  |  |             tFunction((int)tId); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         MNN_CONCURRENCY_END(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int ConvolutionWinograd::bestWinogradUnit(const Convolution2DCommon *common, const Tensor *inputTensor, | 
					
						
							|  |  |  |                                           const Tensor *outputTensor, int threadNumber) { | 
					
						
							|  |  |  |     int ow      = outputTensor->width(); | 
					
						
							|  |  |  |     int oh      = outputTensor->height(); | 
					
						
							|  |  |  |     int oc      = outputTensor->channel(); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |     int unit2   = UP_DIV(ow * oh, CONVOLUTION_TILED_NUMBER * threadNumber); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     int maxUnit = (int)::sqrtf((float)unit2); | 
					
						
							|  |  |  |     maxUnit     = std::min(maxUnit, CONVOLUTION_WINOGRAD_MAX_UNIT); | 
					
						
							|  |  |  |     maxUnit     = std::max(maxUnit, CONVOLUTION_WINOGRAD_MIN_UNIT); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ic           = inputTensor->channel(); | 
					
						
							|  |  |  |     auto kernelSize  = common->kernelY(); | 
					
						
							|  |  |  |     int unit         = CONVOLUTION_WINOGRAD_MIN_UNIT; | 
					
						
							|  |  |  |     float maxRate    = 0.0f; | 
					
						
							|  |  |  |     float originCost = (float)ow * oh * (float)ic * oc * kernelSize * kernelSize; | 
					
						
							|  |  |  |     static std::set<int> supportSu{4, 8}; | 
					
						
							|  |  |  |     for (int u = CONVOLUTION_WINOGRAD_MIN_UNIT; u <= maxUnit; ++u) { | 
					
						
							|  |  |  |         float su = (float)(u + kernelSize - 1); | 
					
						
							|  |  |  |         if (supportSu.find(su) == supportSu.end()) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (nullptr == WinogradFunction::chooseDestTransform((int)su, u)) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         /*Let F(6,3) be choosed when it can speed up from F(2,3) than 0.6*/ | 
					
						
							|  |  |  |         float penalty = (su * su) / (float)(kernelSize * kernelSize) * 0.12f; | 
					
						
							|  |  |  |         float winogradCost = | 
					
						
							| 
									
										
										
										
											2020-03-25 17:41:58 +08:00
										 |  |  |             (2 * su * su * ic + su * su * ic * oc + (su + u) * u * oc) * (UP_DIV(ow, u) * UP_DIV(oh, u)); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         float reduceRate = originCost / winogradCost - penalty; | 
					
						
							|  |  |  |         // MNN_PRINT("ow=%d, oh=%d, %f, %f, winograd unit:%d\n", ow, oh, winogradCost, reduceRate, u);
 | 
					
						
							|  |  |  |         if (reduceRate > maxRate) { | 
					
						
							|  |  |  |             maxRate = reduceRate; | 
					
						
							|  |  |  |             unit    = u; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (maxRate < 1.0f) { | 
					
						
							|  |  |  |         return 0; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return unit; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | bool ConvolutionWinograd::canUseWinograd(const Convolution2DCommon *common) { | 
					
						
							|  |  |  |     if (common->kernelY() != common->kernelX() || common->kernelY() <= 1) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (common->dilateX() != 1 || common->dilateY() != 1) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (common->strideX() != 1 || common->strideY() != 1) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ErrorCode ConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | 
					
						
							|  |  |  |     CPUConvolution::onResize(inputs, outputs); | 
					
						
							|  |  |  |     // FUNC_PRINT(mA->length(1));
 | 
					
						
							|  |  |  |     bool success = backend()->onAcquireBuffer(&mTempBuffer, Backend::DYNAMIC); | 
					
						
							|  |  |  |     success      = success && (backend()->onAcquireBuffer(&mTransformMidBuffer, Backend::DYNAMIC)); | 
					
						
							|  |  |  |     backend()->onReleaseBuffer(&mTempBuffer, Backend::DYNAMIC); | 
					
						
							|  |  |  |     backend()->onReleaseBuffer(&mTransformMidBuffer, Backend::DYNAMIC); | 
					
						
							|  |  |  |     if (!success) { | 
					
						
							|  |  |  |         return OUT_OF_MEMORY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | } // namespace MNN
 |