mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			154 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			154 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  GeometrySpatialProduct.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2020/07/12.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include "geometry/GeometryComputer.hpp"
 | |
| #include "core/OpCommonUtils.hpp"
 | |
| #include "geometry/GeometryComputerUtils.hpp"
 | |
| 
 | |
| namespace MNN {
 | |
| class GeometrySpatialProduct : public GeometryComputer {
 | |
| public:
 | |
|     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
 | |
|                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
 | |
|         // Assume
 | |
|         // bottom[0] dim CxHxW
 | |
|         // bottom[1] dim 1xHxW
 | |
|         // top[0]    dim CxHxW
 | |
|         MNN_ASSERT(2 == inputs.size());
 | |
|         MNN_ASSERT(1 == outputs.size());
 | |
| 
 | |
|         auto input     = inputs[0];
 | |
|         auto input1    = inputs[1];
 | |
|         auto output    = outputs[0];
 | |
|         
 | |
|         int ib      = input->batch();
 | |
|         int iw      = input->width();
 | |
|         int ih      = input->height();
 | |
|         int ic      = input->channel();
 | |
|         
 | |
|         auto ob = output->batch();
 | |
|         auto oc = output->channel();
 | |
|         auto oh = output->height();
 | |
|         auto ow = output->width();
 | |
|         auto inside = iw*ih;
 | |
|         
 | |
|         //input transform to NCHW format
 | |
|         std::shared_ptr<Tensor> tmpInput;
 | |
|         {
 | |
|             tmpInput.reset(new Tensor);
 | |
|             tmpInput->buffer().type = halide_type_of<float>();
 | |
|             tmpInput->buffer().dimensions = 4;
 | |
|             tmpInput->setLength(0, ib);
 | |
|             tmpInput->setLength(1, ic);
 | |
|             tmpInput->setLength(2, ih);
 | |
|             tmpInput->setLength(3, iw);
 | |
|             auto outputDes = TensorUtils::getDescribe(tmpInput.get());
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW;
 | |
| 
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = ib;
 | |
|             desReg.size[1] = ic;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = ic*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = ic*inside;
 | |
|             desReg.src.stride[1] = inside;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = input;
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|             
 | |
|             res.extras.emplace_back(tmpInput);
 | |
|         }
 | |
|         
 | |
|         //input1 broadcast to NCHW format
 | |
|         std::shared_ptr<Tensor> tmpInput1;
 | |
|         {
 | |
|             tmpInput1.reset(new Tensor);
 | |
|             tmpInput1->buffer().type = halide_type_of<float>();
 | |
|             tmpInput1->buffer().dimensions = 4;
 | |
|             tmpInput1->setLength(0, ib);
 | |
|             tmpInput1->setLength(1, ic);
 | |
|             tmpInput1->setLength(2, ih);
 | |
|             tmpInput1->setLength(3, iw);
 | |
|             auto outputDes = TensorUtils::getDescribe(tmpInput1.get());
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW;
 | |
| 
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = ib;
 | |
|             desReg.size[1] = ic;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = ic*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = inside;
 | |
|             desReg.src.stride[1] = 0;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = input1;
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|             
 | |
|             res.extras.emplace_back(tmpInput1);
 | |
|         }
 | |
|         
 | |
|         std::shared_ptr<Tensor> tmpOutput;
 | |
|         {
 | |
|             tmpOutput.reset(new Tensor);
 | |
|             tmpOutput->buffer().type = halide_type_of<float>();
 | |
|             tmpOutput->buffer().dimensions = 4;
 | |
|             tmpOutput->setLength(0, ob);
 | |
|             tmpOutput->setLength(1, oc);
 | |
|             tmpOutput->setLength(2, oh);
 | |
|             tmpOutput->setLength(3, ow);
 | |
|             auto des = TensorUtils::getDescribe(tmpOutput.get());
 | |
|             des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
 | |
|             
 | |
|             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput.get(), tmpInput1.get(), tmpOutput.get());
 | |
|         
 | |
|             res.extras.emplace_back(tmpOutput);
 | |
|             res.command.emplace_back(std::move(cmd));
 | |
|         }
 | |
|         
 | |
|         
 | |
|         //transform to output
 | |
|         {
 | |
|             auto outputDes = TensorUtils::getDescribe(output);
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = ob;
 | |
|             desReg.size[1] = oc;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = oc*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = oc*inside;
 | |
|             desReg.src.stride[1] = inside;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = tmpOutput.get();
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
| };
 | |
| 
 | |
| static void _create() {
 | |
|     std::shared_ptr<GeometryComputer> comp(new GeometrySpatialProduct);
 | |
|     GeometryComputer::registerGeometryComputer(comp, {OpType_SpatialProduct});
 | |
| }
 | |
| 
 | |
| REGISTER_GEOMETRY(GeometrySpatialProduct, _create);
 | |
| 
 | |
| } // namespace MNN
 |