mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			67 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			67 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  GeometryBinary.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2020/05/07.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include "ConvertUtils.hpp"
 | ||
|  | #include "geometry/GeometryComputer.hpp"
 | ||
|  | #include "shape/SizeComputer.hpp"
 | ||
|  | namespace MNN { | ||
|  | class GeometryBinary : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | ||
|  |                            Context& context, CommandBuffer& res) const override { | ||
|  |         auto input0     = inputs[0]; | ||
|  |         auto input1     = inputs[1]; | ||
|  |         auto output     = outputs[0]; | ||
|  |         auto inputL0    = input0->elementSize(); | ||
|  |         auto inputL1    = input1->elementSize(); | ||
|  |         auto outputSize = output->elementSize(); | ||
|  |         MNN_ASSERT(0 != inputL1 && 0 != inputL0 && 0 != outputSize); | ||
|  |         if (1 == inputL0 || 1 == inputL1) { | ||
|  |             // Can directly compute
 | ||
|  |             Command cmd; | ||
|  |             cmd.op      = op; | ||
|  |             cmd.inputs  = {input0, input1}; | ||
|  |             cmd.outputs = std::move(outputs); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |             return true; | ||
|  |         } | ||
|  |         // Need Broadcast or same shape
 | ||
|  |         if (outputSize != inputL0) { | ||
|  |             std::shared_ptr<Tensor> newTensor(new Tensor); | ||
|  |             TensorUtils::copyShape(output, newTensor.get(), true); | ||
|  |             newTensor->buffer().type = output->buffer().type; | ||
|  |             ConvertUtils::broadcastto(input0, newTensor.get()); | ||
|  |             input0 = newTensor.get(); | ||
|  |             res.extras.emplace_back(newTensor); | ||
|  |         } | ||
|  |         if (outputSize != inputL1) { | ||
|  |             std::shared_ptr<Tensor> newTensor(new Tensor); | ||
|  |             TensorUtils::copyShape(output, newTensor.get(), true); | ||
|  |             newTensor->buffer().type = output->buffer().type; | ||
|  |             ConvertUtils::broadcastto(input1, newTensor.get()); | ||
|  |             input1 = newTensor.get(); | ||
|  |             res.extras.emplace_back(newTensor); | ||
|  |         } | ||
|  |         Command cmd; | ||
|  |         cmd.op      = op; | ||
|  |         cmd.inputs  = {input0, input1}; | ||
|  |         cmd.outputs = std::move(outputs); | ||
|  |         res.command.emplace_back(std::move(cmd)); | ||
|  |         return true; | ||
|  |     } | ||
|  | }; | ||
|  | 
 | ||
|  | static void _create() { | ||
|  |     std::shared_ptr<GeometryComputer> comp(new GeometryBinary); | ||
|  |     GeometryComputer::registerGeometryComputer(comp, {OpType_BinaryOp}); | ||
|  | } | ||
|  | 
 | ||
|  | REGISTER_GEOMETRY(GeometryBinary, _create); | ||
|  | 
 | ||
|  | } // namespace MNN
 |