mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			129 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
//
 | 
						|
//  GeometrySpaceToBatchND.cpp
 | 
						|
//  MNN
 | 
						|
//
 | 
						|
//  Created by MNN on 2020/04/20.
 | 
						|
//  Copyright © 2018, Alibaba Group Holding Limited
 | 
						|
//
 | 
						|
 | 
						|
#include "geometry/GeometryComputer.hpp"
 | 
						|
#include "core/Macro.h"
 | 
						|
namespace MNN {
 | 
						|
class GeometrySpaceToBatchND : public GeometryComputer {
 | 
						|
public:
 | 
						|
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
 | 
						|
                           Context& context, CommandBuffer& res) const override {
 | 
						|
        MNN_ASSERT(1 == outputs.size());
 | 
						|
        MNN_ASSERT(inputs.size() == 1 || inputs.size() == 3);
 | 
						|
        int blockSize = 0;
 | 
						|
        const int *blockData, *paddingData;
 | 
						|
        auto param            = op->main_as_SpaceBatch();
 | 
						|
        if (inputs.size() == 3) {
 | 
						|
            blockSize = inputs[1]->length(0);
 | 
						|
            blockData = inputs[1]->host<int32_t>();
 | 
						|
            paddingData = inputs[2]->host<int32_t>();
 | 
						|
        } else {
 | 
						|
            blockSize = param->blockShape()->dims()->data()[0];
 | 
						|
            blockData = param->blockShape()->int32s()->data();
 | 
						|
            paddingData = param->padding()->int32s()->data();
 | 
						|
        }
 | 
						|
        auto padTop           = paddingData[0];
 | 
						|
        auto padLeft          = 0;
 | 
						|
        auto blockShapeHeight = blockData[0];
 | 
						|
        auto blockShapeWidth  = 1;
 | 
						|
        if (blockSize > 1) {
 | 
						|
            padLeft         = paddingData[2];
 | 
						|
            blockShapeWidth = blockData[1];
 | 
						|
        }
 | 
						|
        auto input      = inputs[0];
 | 
						|
        auto output     = outputs[0];
 | 
						|
        auto outputDes  = TensorUtils::getDescribe(output);
 | 
						|
        auto realTensor = input;
 | 
						|
        // For OpType_BatchToSpaceND, swap input and output
 | 
						|
        if (op->type() == OpType_BatchToSpaceND) {
 | 
						|
            auto temp = output;
 | 
						|
            output    = input;
 | 
						|
            input     = temp;
 | 
						|
        }
 | 
						|
 | 
						|
        const int inHeight  = input->height();
 | 
						|
        const int inWidth   = input->width();
 | 
						|
        const int inBatch   = input->batch();
 | 
						|
        const int outHeight = output->height();
 | 
						|
        const int outWidth  = output->width();
 | 
						|
        const int outBatch  = output->batch();
 | 
						|
        auto regionSize     = outBatch / inBatch;
 | 
						|
        auto channel        = output->channel();
 | 
						|
        outputDes->regions.resize(regionSize);
 | 
						|
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | 
						|
        // NCHW stride
 | 
						|
        int inputStride[4];
 | 
						|
        int outputStride[4];
 | 
						|
        if (MNN_DATA_FORMAT_NHWC == outputDes->dimensionFormat) {
 | 
						|
            inputStride[0] = inWidth * inHeight * channel;
 | 
						|
            inputStride[1] = 1;
 | 
						|
            inputStride[2] = inWidth * channel;
 | 
						|
            inputStride[3] = channel;
 | 
						|
 | 
						|
            outputStride[0] = outWidth * outHeight * channel;
 | 
						|
            outputStride[1] = 1;
 | 
						|
            outputStride[2] = outWidth * channel;
 | 
						|
            outputStride[3] = channel;
 | 
						|
        } else {
 | 
						|
            inputStride[0] = inWidth * inHeight * channel;
 | 
						|
            inputStride[1] = inWidth * inHeight;
 | 
						|
            inputStride[2] = inWidth;
 | 
						|
            inputStride[3] = 1;
 | 
						|
 | 
						|
            outputStride[0] = outWidth * outHeight * channel;
 | 
						|
            outputStride[1] = outHeight * outWidth;
 | 
						|
            outputStride[2] = outWidth;
 | 
						|
            outputStride[3] = 1;
 | 
						|
        }
 | 
						|
        for (int r = 0; r < regionSize; ++r) {
 | 
						|
            auto& region  = outputDes->regions[r];
 | 
						|
            region.origin = realTensor;
 | 
						|
            int strideW   = r % blockShapeWidth;
 | 
						|
            int strideH   = r / blockShapeWidth;
 | 
						|
 | 
						|
            const int validHStart = ALIMAX(0, (padTop - strideH + blockShapeHeight - 1) / blockShapeHeight);
 | 
						|
            const int validHEnd =
 | 
						|
                ALIMIN(outHeight, (inHeight + padTop - strideH + blockShapeHeight - 1) / blockShapeHeight);
 | 
						|
            const int validWStart = ALIMAX(0, (padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth);
 | 
						|
            const int validWEnd =
 | 
						|
                ALIMIN(outWidth, (inWidth + padLeft - strideW + blockShapeWidth - 1) / blockShapeWidth);
 | 
						|
            int inHeightStart = validHStart * blockShapeHeight + strideH - padTop;
 | 
						|
            int inWidthStart  = validHStart * blockShapeWidth + strideW - padLeft;
 | 
						|
            auto srcR         = ®ion.src;
 | 
						|
            auto dstR         = ®ion.dst;
 | 
						|
            if (op->type() == OpType_BatchToSpaceND) {
 | 
						|
                srcR = ®ion.dst;
 | 
						|
                dstR = ®ion.src;
 | 
						|
            }
 | 
						|
            srcR->offset    = inHeightStart * inputStride[2] + inWidthStart * inputStride[3];
 | 
						|
            srcR->stride[0] = 1 * inputStride[1];
 | 
						|
            srcR->stride[1] = blockShapeHeight * inputStride[2];
 | 
						|
            srcR->stride[2] = blockShapeWidth * inputStride[3];
 | 
						|
 | 
						|
            region.size[0] = inBatch * channel;
 | 
						|
            region.size[1] = validHEnd - validHStart;
 | 
						|
            region.size[2] = validWEnd - validWStart;
 | 
						|
 | 
						|
            dstR->offset =
 | 
						|
                outputStride[2] * validHStart + outputStride[3] * validWStart + r * inBatch * outputStride[0];
 | 
						|
            dstR->stride[0] = outputStride[1];
 | 
						|
            dstR->stride[1] = outputStride[2];
 | 
						|
            dstR->stride[2] = outputStride[3];
 | 
						|
        }
 | 
						|
        return true;
 | 
						|
    }
 | 
						|
};
 | 
						|
static void _create() {
 | 
						|
    std::shared_ptr<GeometryComputer> comp(new GeometrySpaceToBatchND);
 | 
						|
    GeometryComputer::registerGeometryComputer(comp, {OpType_SpaceToBatchND, OpType_BatchToSpaceND});
 | 
						|
}
 | 
						|
 | 
						|
// Hands _create to the REGISTER_GEOMETRY macro (defined outside this file; presumably it
// arranges for _create to run when geometry computers are initialized - confirm).
REGISTER_GEOMETRY(GeometrySpaceToBatchND, _create);
 | 
						|
 | 
						|
} // namespace MNN
 |