mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			85 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			85 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  GeometryReshape.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2020/04/03.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
#include <memory>

#include "ConvertUtils.hpp"
#include "geometry/GeometryComputer.hpp"
 | ||
|  | namespace MNN { | ||
|  | class GeometryReshape : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | ||
|  |                            Context& context, CommandBuffer& res) const override { | ||
|  |         auto input     = inputs[0]; | ||
|  |         auto output    = outputs[0]; | ||
|  |         auto inputDes  = TensorUtils::getDescribe(input); | ||
|  |         auto outputDes = TensorUtils::getDescribe(output); | ||
|  |         if (TensorUtils::getDescribe(input)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { | ||
|  |             auto midFormat = op->main_as_Reshape()->dimType(); | ||
|  |             if (MNN_DATA_FORMAT_NHWC == midFormat) { | ||
|  |                 // Convert to NHWC, reshape, and then convert to NC4HW4
 | ||
|  |                 std::shared_ptr<Tensor> nhwc(new Tensor); | ||
|  |                 TensorUtils::setupTensorInfo(input, nhwc.get(), MNN_DATA_FORMAT_NHWC); | ||
|  |                 ConvertUtils::compute(input, nhwc.get(), res); | ||
|  |                 res.extras.emplace_back(nhwc); | ||
|  |                 std::shared_ptr<Tensor> nhwc2(new Tensor); | ||
|  |                 TensorUtils::setupTensorInfo(output, nhwc2.get(), MNN_DATA_FORMAT_NHWC); | ||
|  |                 res.extras.emplace_back(nhwc2); | ||
|  |                 { | ||
|  |                     auto inputSlice = TensorUtils::getDescribe(nhwc.get())->regions; | ||
|  |                     if (inputSlice.empty()) { | ||
|  |                         // Create Full Refence
 | ||
|  |                         Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(nhwc.get()); | ||
|  |                         inputSlice.emplace_back(std::move(totalSlice)); | ||
|  |                     } | ||
|  |                     TensorUtils::getDescribe(nhwc2.get())->regions    = std::move(inputSlice); | ||
|  |                     TensorUtils::getDescribe(nhwc2.get())->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |                 } | ||
|  |                 ConvertUtils::compute(nhwc2.get(), output, res); | ||
|  |                 return true; | ||
|  |             } | ||
|  |         } | ||
|  |         auto inputSlice = inputDes->regions; | ||
|  |         if (inputSlice.empty()) { | ||
|  |             // Create Full Refence
 | ||
|  |             Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input); | ||
|  |             inputSlice.emplace_back(std::move(totalSlice)); | ||
|  |         } | ||
|  |         outputDes->regions    = std::move(inputSlice); | ||
|  |         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |         return true; | ||
|  |     } | ||
|  | }; | ||
|  | class SingleGeometryComputer : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | ||
|  |                            Context& context, CommandBuffer& res) const override { | ||
|  |         auto input      = inputs[0]; | ||
|  |         auto output     = outputs[0]; | ||
|  |         auto inputDes   = TensorUtils::getDescribe(input); | ||
|  |         auto outputDes  = TensorUtils::getDescribe(output); | ||
|  |         auto inputSlice = inputDes->regions; | ||
|  |         if (inputSlice.empty()) { | ||
|  |             // Create Full Refence
 | ||
|  |             Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input); | ||
|  |             inputSlice.emplace_back(std::move(totalSlice)); | ||
|  |         } | ||
|  |         outputDes->regions    = std::move(inputSlice); | ||
|  |         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |         return true; | ||
|  |     } | ||
|  | }; | ||
|  | 
 | ||
|  | static void _create() { | ||
|  |     std::shared_ptr<GeometryComputer> comp(new GeometryReshape); | ||
|  |     GeometryComputer::registerGeometryComputer(comp, {OpType_Reshape}); | ||
|  |     std::shared_ptr<GeometryComputer> _comp(new SingleGeometryComputer); | ||
|  |     GeometryComputer::registerGeometryComputer(_comp, {OpType_Squeeze, OpType_Unsqueeze, OpType_ExpandDims, OpType_Flatten, OpType_QuantizedReshape}); | ||
|  | } | ||
|  | 
 | ||
|  | REGISTER_GEOMETRY(GeometryReshape, _create); | ||
|  | }; // namespace MNN
 |