mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
//
//  GeometrySpaceToBatchND.cpp
//  MNN
//
//  Created by MNN on 2020/04/20.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|  | 
 | ||
|  | #include "geometry/GeometryComputer.hpp"
 | ||
|  | #include "core/Macro.h"
 | ||
|  | namespace MNN { | ||
|  | class GeometrySpaceToBatchND : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | ||
|  |                            Context& context, CommandBuffer& res) const override { | ||
|  |         MNN_ASSERT(1 == outputs.size()); | ||
|  |         MNN_ASSERT(inputs.size() == 1 || inputs.size() == 3); | ||
|  |         int blockSize = 0; | ||
|  |         const int *blockData, *paddingData; | ||
|  |         auto param            = op->main_as_SpaceBatch(); | ||
|  |         if (inputs.size() == 3) { | ||
|  |             blockSize = inputs[1]->length(0); | ||
|  |             blockData = inputs[1]->host<int32_t>(); | ||
|  |             paddingData = inputs[2]->host<int32_t>(); | ||
|  |         } else { | ||
|  |             blockSize = param->blockShape()->dims()->data()[0]; | ||
|  |             blockData = param->blockShape()->int32s()->data(); | ||
|  |             paddingData = param->padding()->int32s()->data(); | ||
|  |         } | ||
|  |         auto padTop           = paddingData[0]; | ||
|  |         auto padLeft          = 0; | ||
|  |         auto blockShapeHeight = blockData[0]; | ||
|  |         auto blockShapeWidth  = 1; | ||
|  |         if (blockSize > 1) { | ||
|  |             padLeft         = paddingData[2]; | ||
|  |             blockShapeWidth = blockData[1]; | ||
|  |         } | ||
|  |         auto input      = inputs[0]; | ||
|  |         auto output     = outputs[0]; | ||
|  |         auto outputDes  = TensorUtils::getDescribe(output); | ||
|  |         auto realTensor = input; | ||
|  |         // For OpType_BatchToSpaceND, swap input and output
 | ||
|  |         if (op->type() == OpType_BatchToSpaceND) { | ||
|  |             auto temp = output; | ||
|  |             output    = input; | ||
|  |             input     = temp; | ||
|  |         } | ||
|  | 
 | ||
|  |         const int inHeight  = input->height(); | ||
|  |         const int inWidth   = input->width(); | ||
|  |         const int inBatch   = input->batch(); | ||
|  |         const int outHeight = output->height(); | ||
|  |         const int outWidth  = output->width(); | ||
|  |         const int outBatch  = output->batch(); | ||
|  |         auto regionSize     = outBatch / inBatch; | ||
|  |         auto channel        = output->channel(); | ||
|  |         outputDes->regions.resize(regionSize); | ||
|  |         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |         // NCHW stride
 | ||
|  |         int inputStride[4]; | ||
|  |         int outputStride[4]; | ||
|  |         if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) { | ||
|  |             inputStride[0] = inWidth * inHeight * channel; | ||
|  |             inputStride[1] = 1; | ||
|  |             inputStride[2] = inWidth * channel; | ||
|  |             inputStride[3] = channel; | ||
|  | 
 | ||
|  |             outputStride[0] = outWidth * outHeight * channel; | ||
|  |             outputStride[1] = 1; | ||
|  |             outputStride[2] = outWidth * channel; | ||
|  |             outputStride[3] = channel; | ||
|  |         } else { | ||
|  |             inputStride[0] = inWidth * inHeight * channel; | ||
|  |             inputStride[1] = inWidth * inHeight; | ||
|  |             inputStride[2] = inWidth; | ||
|  |             inputStride[3] = 1; | ||
|  | 
 | ||
|  |             outputStride[0] = outWidth * outHeight * channel; | ||
|  |             outputStride[1] = outHeight * outWidth; | ||
|  |             outputStride[2] = outWidth; | ||
|  |             outputStride[3] = 1; | ||
|  |         } | ||
|  |         for (int r = 0; r < regionSize; ++r) { | ||
|  |             auto& region  = outputDes->regions[r]; | ||
|  |             region.origin = realTensor; | ||
|  |             int strideW   = r % blockShapeWidth; | ||
|  |             int strideH   = r / blockShapeWidth; | ||
|  | 
 | ||
|  |             const int validHStart = ALIMAX(0, (padTop - strideH + blockShapeHeight - 1) / blockShapeHeight); | ||
|  |             const int validHEnd = | ||
|  |                 ALIMIN(outHeight, (inHeight + padTop - strideH + blockShapeHeight - 1) / blockShapeHeight); | ||
|  |             const int validWStart = ALIMAX(0, (padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth); | ||
|  |             const int validWEnd = | ||
|  |                 ALIMIN(outWidth, (inWidth + padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth); | ||
|  |             int inHeightStart = validHStart * blockShapeHeight + strideH - padTop; | ||
|  |             int inWidthStart  = validHStart * blockShapeWidth + strideW - padLeft; | ||
|  |             auto srcR         = ®ion.src; | ||
|  |             auto dstR         = ®ion.dst; | ||
|  |             if (op->type() == OpType_BatchToSpaceND) { | ||
|  |                 srcR = ®ion.dst; | ||
|  |                 dstR = ®ion.src; | ||
|  |             } | ||
|  |             srcR->offset    = inHeightStart * inputStride[2] + inWidthStart * inputStride[3]; | ||
|  |             srcR->stride[0] = 1 * inputStride[1]; | ||
|  |             srcR->stride[1] = blockShapeHeight * inputStride[2]; | ||
|  |             srcR->stride[2] = blockShapeWidth * inputStride[3]; | ||
|  | 
 | ||
|  |             region.size[0] = inBatch * channel; | ||
|  |             region.size[1] = validHEnd - validHStart; | ||
|  |             region.size[2] = validWEnd - validWStart; | ||
|  | 
 | ||
|  |             dstR->offset = | ||
|  |                 outputStride[2] * validHStart + outputStride[3] * validWStart + r * inBatch * outputStride[0]; | ||
|  |             dstR->stride[0] = outputStride[1]; | ||
|  |             dstR->stride[1] = outputStride[2]; | ||
|  |             dstR->stride[2] = outputStride[3]; | ||
|  |         } | ||
|  |         return true; | ||
|  |     } | ||
|  | }; | ||
|  | static void _create() { | ||
|  |     std::shared_ptr<GeometryComputer> comp(new GeometrySpaceToBatchND); | ||
|  |     GeometryComputer::registerGeometryComputer(comp, {OpType_SpaceToBatchND, OpType_BatchToSpaceND}); | ||
|  | } | ||
|  | 
 | ||
|  | REGISTER_GEOMETRY(GeometrySpaceToBatchND, _create); | ||
|  | 
 | ||
|  | } // namespace MNN
 |