// Mirror of https://github.com/alibaba/MNN.git
//
//  GeometryDepthToSpace.cpp
//  MNN
//
//  Created by MNN on 2020/04/23.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <utility>

#include "core/Macro.h"
#include "geometry/GeometryComputer.hpp"

namespace MNN {
 | 
						|
class GeometryDepthToSpace : public GeometryComputer {
 | 
						|
public:
 | 
						|
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
 | 
						|
                           Context& context, CommandBuffer& res) const override {
 | 
						|
        MNN_ASSERT(1 == outputs.size());
 | 
						|
        MNN_ASSERT(inputs.size() == 1);
 | 
						|
        const int blockSize   = op->main_as_DepthSpaceParam()->blockSize();
 | 
						|
        auto mode = op->main_as_DepthSpaceParam()->mode();
 | 
						|
        auto input            = inputs[0];
 | 
						|
        auto output           = outputs[0];
 | 
						|
        auto outputDes        = TensorUtils::getDescribe(output);
 | 
						|
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | 
						|
        auto realTensor       = input;
 | 
						|
        // For OpType_SpaceToDepth, swap input and output
 | 
						|
        if (op->type() == OpType_SpaceToDepth) {
 | 
						|
            auto temp = output;
 | 
						|
            output    = input;
 | 
						|
            input     = temp;
 | 
						|
        }
 | 
						|
 | 
						|
        const int inHeight   = input->height();
 | 
						|
        const int inWidth    = input->width();
 | 
						|
        const int inChannel  = input->channel();
 | 
						|
        const int outHeight  = output->height();
 | 
						|
        const int outWidth   = output->width();
 | 
						|
        const int outChannel = output->channel();
 | 
						|
        // NCHW Stride
 | 
						|
        int inputStride[4];
 | 
						|
        int outputStride[4];
 | 
						|
        if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) {
 | 
						|
            inputStride[0] = inWidth * inHeight * inChannel;
 | 
						|
            inputStride[1] = 1;
 | 
						|
            inputStride[2] = inWidth * inChannel;
 | 
						|
            inputStride[3] = inChannel;
 | 
						|
 | 
						|
            outputStride[0] = outWidth * outHeight * outChannel;
 | 
						|
            outputStride[1] = 1;
 | 
						|
            outputStride[2] = outWidth * outChannel;
 | 
						|
            outputStride[3] = outChannel;
 | 
						|
        } else {
 | 
						|
            inputStride[0] = inWidth * inHeight * inChannel;
 | 
						|
            inputStride[1] = inWidth * inHeight;
 | 
						|
            inputStride[2] = inWidth;
 | 
						|
            inputStride[3] = 1;
 | 
						|
 | 
						|
            outputStride[0] = outWidth * outHeight * outChannel;
 | 
						|
            outputStride[1] = outHeight * outWidth;
 | 
						|
            outputStride[2] = outWidth;
 | 
						|
            outputStride[3] = 1;
 | 
						|
        }
 | 
						|
        auto batch      = input->batch();
 | 
						|
        auto regionSize = blockSize * blockSize * batch;
 | 
						|
        outputDes->regions.resize(regionSize);
 | 
						|
        for (int b = 0; b < batch; ++b) {
 | 
						|
            auto dstB = b * outputStride[0];
 | 
						|
            auto srcB = b * inputStride[0];
 | 
						|
            for (int hb = 0; hb < blockSize; ++hb) {
 | 
						|
                auto dstHB = dstB + hb * outputStride[2];
 | 
						|
                for (int wb = 0; wb < blockSize; ++wb) {
 | 
						|
                    auto dstWB        = dstHB + wb * outputStride[3];
 | 
						|
                    int offsetC = hb * blockSize + wb;
 | 
						|
                    if (mode == DepthToSpaceMode_DCR) {
 | 
						|
                        offsetC *= outChannel;
 | 
						|
                    }
 | 
						|
                    auto srcWB        = srcB + offsetC * inputStride[1];
 | 
						|
 | 
						|
                    auto& region   = outputDes->regions[b * blockSize * blockSize + wb + hb * blockSize];
 | 
						|
                    region.origin  = realTensor;
 | 
						|
                    region.size[0] = inHeight;
 | 
						|
                    region.size[1] = inWidth;
 | 
						|
                    region.size[2] = outChannel;
 | 
						|
 | 
						|
                    auto srcR = ®ion.src;
 | 
						|
                    auto dstR = ®ion.dst;
 | 
						|
                    if (op->type() == OpType_SpaceToDepth) {
 | 
						|
                        srcR = ®ion.dst;
 | 
						|
                        dstR = ®ion.src;
 | 
						|
                    }
 | 
						|
 | 
						|
                    dstR->offset    = dstWB;
 | 
						|
                    dstR->stride[0] = outputStride[2] * blockSize;
 | 
						|
                    dstR->stride[1] = outputStride[3] * blockSize;
 | 
						|
                    dstR->stride[2] = outputStride[1];
 | 
						|
 | 
						|
                    srcR->offset    = srcWB;
 | 
						|
                    srcR->stride[0] = inputStride[2];
 | 
						|
                    srcR->stride[1] = inputStride[3];
 | 
						|
                    srcR->stride[2] = inputStride[1];
 | 
						|
                    if (mode == DepthToSpaceMode_CRD) {
 | 
						|
                        srcR->stride[2] *= (blockSize * blockSize);
 | 
						|
                    }
 | 
						|
                }
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return true;
 | 
						|
    }
 | 
						|
};
 | 
						|
static void _create() {
 | 
						|
    std::shared_ptr<GeometryComputer> comp(new GeometryDepthToSpace);
 | 
						|
    GeometryComputer::registerGeometryComputer(comp, {OpType_DepthToSpace, OpType_SpaceToDepth});
 | 
						|
}
 | 
						|
 | 
						|
REGISTER_GEOMETRY(GeometryDepthToSpace, _create);
 | 
						|
 | 
						|
} // namespace MNN
 |