mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			117 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			117 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  GeometryDepthToSpace.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2020/04/23.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include "geometry/GeometryComputer.hpp"
 | |
| #include "core/Macro.h"
 | |
| namespace MNN {
 | |
| class GeometryDepthToSpace : public GeometryComputer {
 | |
| public:
 | |
|     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
 | |
|                            Context& context, CommandBuffer& res) const override {
 | |
|         MNN_ASSERT(1 == outputs.size());
 | |
|         MNN_ASSERT(inputs.size() == 1);
 | |
|         const int blockSize   = op->main_as_DepthSpaceParam()->blockSize();
 | |
|         auto mode = op->main_as_DepthSpaceParam()->mode();
 | |
|         auto input            = inputs[0];
 | |
|         auto output           = outputs[0];
 | |
|         auto outputDes        = TensorUtils::getDescribe(output);
 | |
|         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|         auto realTensor       = input;
 | |
|         // For OpType_SpaceToDepth, swap input and output
 | |
|         if (op->type() == OpType_SpaceToDepth) {
 | |
|             auto temp = output;
 | |
|             output    = input;
 | |
|             input     = temp;
 | |
|         }
 | |
| 
 | |
|         const int inHeight   = input->height();
 | |
|         const int inWidth    = input->width();
 | |
|         const int inChannel  = input->channel();
 | |
|         const int outHeight  = output->height();
 | |
|         const int outWidth   = output->width();
 | |
|         const int outChannel = output->channel();
 | |
|         // NCHW Stride
 | |
|         int inputStride[4];
 | |
|         int outputStride[4];
 | |
|         if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) {
 | |
|             inputStride[0] = inWidth * inHeight * inChannel;
 | |
|             inputStride[1] = 1;
 | |
|             inputStride[2] = inWidth * inChannel;
 | |
|             inputStride[3] = inChannel;
 | |
| 
 | |
|             outputStride[0] = outWidth * outHeight * outChannel;
 | |
|             outputStride[1] = 1;
 | |
|             outputStride[2] = outWidth * outChannel;
 | |
|             outputStride[3] = outChannel;
 | |
|         } else {
 | |
|             inputStride[0] = inWidth * inHeight * inChannel;
 | |
|             inputStride[1] = inWidth * inHeight;
 | |
|             inputStride[2] = inWidth;
 | |
|             inputStride[3] = 1;
 | |
| 
 | |
|             outputStride[0] = outWidth * outHeight * outChannel;
 | |
|             outputStride[1] = outHeight * outWidth;
 | |
|             outputStride[2] = outWidth;
 | |
|             outputStride[3] = 1;
 | |
|         }
 | |
|         auto batch      = input->batch();
 | |
|         auto regionSize = blockSize * blockSize * batch;
 | |
|         outputDes->regions.resize(regionSize);
 | |
|         for (int b = 0; b < batch; ++b) {
 | |
|             auto dstB = b * outputStride[0];
 | |
|             auto srcB = b * inputStride[0];
 | |
|             for (int hb = 0; hb < blockSize; ++hb) {
 | |
|                 auto dstHB = dstB + hb * outputStride[2];
 | |
|                 for (int wb = 0; wb < blockSize; ++wb) {
 | |
|                     auto dstWB        = dstHB + wb * outputStride[3];
 | |
|                     int offsetC = hb * blockSize + wb;
 | |
|                     if (mode == DepthToSpaceMode_DCR) {
 | |
|                         offsetC *= outChannel;
 | |
|                     }
 | |
|                     auto srcWB        = srcB + offsetC * inputStride[1];
 | |
| 
 | |
|                     auto& region   = outputDes->regions[b * blockSize * blockSize + wb + hb * blockSize];
 | |
|                     region.origin  = realTensor;
 | |
|                     region.size[0] = inHeight;
 | |
|                     region.size[1] = inWidth;
 | |
|                     region.size[2] = outChannel;
 | |
| 
 | |
|                     auto srcR = ®ion.src;
 | |
|                     auto dstR = ®ion.dst;
 | |
|                     if (op->type() == OpType_SpaceToDepth) {
 | |
|                         srcR = ®ion.dst;
 | |
|                         dstR = ®ion.src;
 | |
|                     }
 | |
| 
 | |
|                     dstR->offset    = dstWB;
 | |
|                     dstR->stride[0] = outputStride[2] * blockSize;
 | |
|                     dstR->stride[1] = outputStride[3] * blockSize;
 | |
|                     dstR->stride[2] = outputStride[1];
 | |
| 
 | |
|                     srcR->offset    = srcWB;
 | |
|                     srcR->stride[0] = inputStride[2];
 | |
|                     srcR->stride[1] = inputStride[3];
 | |
|                     srcR->stride[2] = inputStride[1];
 | |
|                     if (mode == DepthToSpaceMode_CRD) {
 | |
|                         srcR->stride[2] *= (blockSize * blockSize);
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
| };
 | |
| static void _create() {
 | |
|     std::shared_ptr<GeometryComputer> comp(new GeometryDepthToSpace);
 | |
|     GeometryComputer::registerGeometryComputer(comp, {OpType_DepthToSpace, OpType_SpaceToDepth});
 | |
| }
 | |
| 
 | |
| REGISTER_GEOMETRY(GeometryDepthToSpace, _create);
 | |
| 
 | |
| } // namespace MNN
 |