mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  GeometrySpaceToBatchND.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2020/04/20.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include "geometry/GeometryComputer.hpp"
 | |
| #include "core/Macro.h"
 | |
| namespace MNN {
 | |
| class GeometrySpaceToBatchND : public GeometryComputer {
 | |
| public:
 | |
|     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
 | |
|                            Context& context, CommandBuffer& res) const override {
 | |
|         MNN_ASSERT(1 == outputs.size());
 | |
|         MNN_ASSERT(inputs.size() == 1 || inputs.size() == 3);
 | |
|         int blockSize = 0;
 | |
|         const int *blockData, *paddingData;
 | |
|         auto param            = op->main_as_SpaceBatch();
 | |
|         if (inputs.size() == 3) {
 | |
|             blockSize = inputs[1]->length(0);
 | |
|             blockData = inputs[1]->host<int32_t>();
 | |
|             paddingData = inputs[2]->host<int32_t>();
 | |
|         } else {
 | |
|             blockSize = param->blockShape()->dims()->data()[0];
 | |
|             blockData = param->blockShape()->int32s()->data();
 | |
|             paddingData = param->padding()->int32s()->data();
 | |
|         }
 | |
|         auto padTop           = paddingData[0];
 | |
|         auto padLeft          = 0;
 | |
|         auto blockShapeHeight = blockData[0];
 | |
|         auto blockShapeWidth  = 1;
 | |
|         if (blockSize > 1) {
 | |
|             padLeft         = paddingData[2];
 | |
|             blockShapeWidth = blockData[1];
 | |
|         }
 | |
|         auto input      = inputs[0];
 | |
|         auto output     = outputs[0];
 | |
|         auto outputDes  = TensorUtils::getDescribe(output);
 | |
|         auto realTensor = input;
 | |
|         // For OpType_BatchToSpaceND, swap input and output
 | |
|         if (op->type() == OpType_BatchToSpaceND) {
 | |
|             auto temp = output;
 | |
|             output    = input;
 | |
|             input     = temp;
 | |
|         }
 | |
| 
 | |
|         const int inHeight  = input->height();
 | |
|         const int inWidth   = input->width();
 | |
|         const int inBatch   = input->batch();
 | |
|         const int outHeight = output->height();
 | |
|         const int outWidth  = output->width();
 | |
|         const int outBatch  = output->batch();
 | |
|         auto regionSize     = outBatch / inBatch;
 | |
|         auto channel        = output->channel();
 | |
|         outputDes->regions.resize(regionSize);
 | |
|         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|         // NCHW stride
 | |
|         int inputStride[4];
 | |
|         int outputStride[4];
 | |
|         if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) {
 | |
|             inputStride[0] = inWidth * inHeight * channel;
 | |
|             inputStride[1] = 1;
 | |
|             inputStride[2] = inWidth * channel;
 | |
|             inputStride[3] = channel;
 | |
| 
 | |
|             outputStride[0] = outWidth * outHeight * channel;
 | |
|             outputStride[1] = 1;
 | |
|             outputStride[2] = outWidth * channel;
 | |
|             outputStride[3] = channel;
 | |
|         } else {
 | |
|             inputStride[0] = inWidth * inHeight * channel;
 | |
|             inputStride[1] = inWidth * inHeight;
 | |
|             inputStride[2] = inWidth;
 | |
|             inputStride[3] = 1;
 | |
| 
 | |
|             outputStride[0] = outWidth * outHeight * channel;
 | |
|             outputStride[1] = outHeight * outWidth;
 | |
|             outputStride[2] = outWidth;
 | |
|             outputStride[3] = 1;
 | |
|         }
 | |
|         for (int r = 0; r < regionSize; ++r) {
 | |
|             auto& region  = outputDes->regions[r];
 | |
|             region.origin = realTensor;
 | |
|             int strideW   = r % blockShapeWidth;
 | |
|             int strideH   = r / blockShapeWidth;
 | |
| 
 | |
|             const int validHStart = ALIMAX(0, (padTop - strideH + blockShapeHeight - 1) / blockShapeHeight);
 | |
|             const int validHEnd =
 | |
|                 ALIMIN(outHeight, (inHeight + padTop - strideH + blockShapeHeight - 1) / blockShapeHeight);
 | |
|             const int validWStart = ALIMAX(0, (padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth);
 | |
|             const int validWEnd =
 | |
|                 ALIMIN(outWidth, (inWidth + padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth);
 | |
|             int inHeightStart = validHStart * blockShapeHeight + strideH - padTop;
 | |
|             int inWidthStart  = validHStart * blockShapeWidth + strideW - padLeft;
 | |
|             auto srcR         = ®ion.src;
 | |
|             auto dstR         = ®ion.dst;
 | |
|             if (op->type() == OpType_BatchToSpaceND) {
 | |
|                 srcR = ®ion.dst;
 | |
|                 dstR = ®ion.src;
 | |
|             }
 | |
|             srcR->offset    = inHeightStart * inputStride[2] + inWidthStart * inputStride[3];
 | |
|             srcR->stride[0] = 1 * inputStride[1];
 | |
|             srcR->stride[1] = blockShapeHeight * inputStride[2];
 | |
|             srcR->stride[2] = blockShapeWidth * inputStride[3];
 | |
| 
 | |
|             region.size[0] = inBatch * channel;
 | |
|             region.size[1] = validHEnd - validHStart;
 | |
|             region.size[2] = validWEnd - validWStart;
 | |
| 
 | |
|             dstR->offset =
 | |
|                 outputStride[2] * validHStart + outputStride[3] * validWStart + r * inBatch * outputStride[0];
 | |
|             dstR->stride[0] = outputStride[1];
 | |
|             dstR->stride[1] = outputStride[2];
 | |
|             dstR->stride[2] = outputStride[3];
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
| };
 | |
| static void _create() {
 | |
|     std::shared_ptr<GeometryComputer> comp(new GeometrySpaceToBatchND);
 | |
|     GeometryComputer::registerGeometryComputer(comp, {OpType_SpaceToBatchND, OpType_BatchToSpaceND});
 | |
| }
 | |
| 
 | |
| REGISTER_GEOMETRY(GeometrySpaceToBatchND, _create);
 | |
| 
 | |
| } // namespace MNN
 |