mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			215 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			215 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  GeometrySoftmax.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2020/06/28.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include "geometry/GeometryComputer.hpp"
 | ||
|  | #include "core/OpCommonUtils.hpp"
 | ||
|  | #include "geometry/GeometryComputerUtils.hpp"
 | ||
|  | 
 | ||
|  | namespace MNN { | ||
|  | class GeometrySoftmax : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual std::vector<bool> onGetOutputVirtual(const Op* op, const std::vector<Tensor*>& inputs, | ||
|  |                                                  const std::vector<Tensor*>& outputs) const override { | ||
|  |         auto  axis = op->main_as_Axis()->axis(); | ||
|  |         if (axis < 0) { | ||
|  |             axis = inputs[0]->dimensions() + axis; | ||
|  |         } | ||
|  |          | ||
|  |         if (axis == 1) { | ||
|  |             return std::vector<bool>(outputs.size(), false); | ||
|  |         } | ||
|  |         return std::vector<bool>(outputs.size(), true); | ||
|  |     } | ||
|  |      | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, | ||
|  |                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override { | ||
|  |         MNN_ASSERT(1 == inputs.size()); | ||
|  |         MNN_ASSERT(1 == outputs.size()); | ||
|  | 
 | ||
|  |         auto input     = inputs[0]; | ||
|  |         auto output    = outputs[0]; | ||
|  |         auto dims      = input->buffer().dimensions; | ||
|  |          | ||
|  |         auto  axis = op->main_as_Axis()->axis(); | ||
|  |         if (axis < 0) { | ||
|  |             axis = inputs[0]->dimensions() + axis; | ||
|  |         } | ||
|  |          | ||
|  |         if (axis == 1) { | ||
|  |             Command cmd; | ||
|  |             cmd.op      = op; | ||
|  |             cmd.inputs  = std::move(inputs); | ||
|  |             cmd.outputs = std::move(outputs); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |             return true; | ||
|  |         } | ||
|  |          | ||
|  |         int inside  = 1; | ||
|  |         int outside = 1; | ||
|  |         int channel = 1; | ||
|  |         for (int i = 0; i < axis; ++i) { | ||
|  |             outside *= input->length(i); | ||
|  |         } | ||
|  |         channel = input->length(axis); | ||
|  |         for (int i = axis + 1; i < dims; ++i) { | ||
|  |             inside *= input->length(i); | ||
|  |         } | ||
|  | 
 | ||
|  |         //input transform to NCHW format
 | ||
|  |         std::shared_ptr<Tensor> tmpInput; | ||
|  |         { | ||
|  |             tmpInput.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto outputDes = TensorUtils::getDescribe(tmpInput.get()); | ||
|  |             outputDes->regions.clear(); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  | 
 | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = channel*inside; | ||
|  |             desReg.src.stride[1] = inside; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = input; | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpInput); | ||
|  |         } | ||
|  |          | ||
|  |         //reduction max, axis=1
 | ||
|  |         std::shared_ptr<Tensor> maxValue; | ||
|  |         { | ||
|  |             maxValue.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             res.extras.emplace_back(maxValue); | ||
|  |             res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_MAXIMUM, tmpInput.get(), maxValue.get())); | ||
|  |         } | ||
|  |          | ||
|  |         //broadcast reduction axis dim
 | ||
|  |         std::shared_ptr<Tensor> maxBroadValue; | ||
|  |         { | ||
|  |             maxBroadValue.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto outputDes = TensorUtils::getDescribe(maxBroadValue.get()); | ||
|  |             outputDes->regions.clear(); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |              | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = inside; | ||
|  |             desReg.src.stride[1] = 0; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = maxValue.get(); | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  |              | ||
|  |             res.extras.emplace_back(maxBroadValue); | ||
|  |         } | ||
|  | 
 | ||
|  |         //sub
 | ||
|  |         std::shared_ptr<Tensor> subMaxValue; | ||
|  |         { | ||
|  |             subMaxValue.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_SUB, tmpInput.get(), maxBroadValue.get(), subMaxValue.get()); | ||
|  |             res.extras.emplace_back(subMaxValue); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |         //exp
 | ||
|  |         std::shared_ptr<Tensor> expValue; | ||
|  |         { | ||
|  |             expValue.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_EXP, subMaxValue.get(), expValue.get()); | ||
|  |             res.extras.emplace_back(expValue); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |              | ||
|  |         } | ||
|  |          | ||
|  |         //reduction sum, axis=2, only support NCHW
 | ||
|  |         std::shared_ptr<Tensor> sumValue; | ||
|  |         { | ||
|  |             sumValue.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             res.extras.emplace_back(sumValue); | ||
|  |             res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_SUM, expValue.get(), sumValue.get())); | ||
|  |         } | ||
|  |          | ||
|  |         //broadcast reduction axis dim
 | ||
|  |         std::shared_ptr<Tensor> sumBroadValue; | ||
|  |         { | ||
|  |             sumBroadValue.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto outputDes = TensorUtils::getDescribe(sumBroadValue.get()); | ||
|  |             outputDes->regions.clear(); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |              | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = inside; | ||
|  |             desReg.src.stride[1] = 0; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = sumValue.get(); | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  | 
 | ||
|  |             res.extras.emplace_back(sumBroadValue); | ||
|  |         } | ||
|  | 
 | ||
|  |         //div
 | ||
|  |         std::shared_ptr<Tensor> tmpOutput; | ||
|  |         { | ||
|  |             tmpOutput.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, expValue.get(), sumBroadValue.get(), tmpOutput.get()); | ||
|  |             res.extras.emplace_back(tmpOutput); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  | 
 | ||
|  |         //transform to output
 | ||
|  |         { | ||
|  |             auto outputDes = TensorUtils::getDescribe(output); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = channel*inside; | ||
|  |             desReg.src.stride[1] = inside; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = tmpOutput.get(); | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  |         } | ||
|  |         return true; | ||
|  |     } | ||
|  | }; | ||
|  | 
 | ||
|  | static void _create() { | ||
|  | //    std::shared_ptr<GeometryComputer> comp(new GeometrySoftmax);
 | ||
|  | //    GeometryComputer::registerGeometryComputer(comp, {OpType_Softmax});
 | ||
|  | } | ||
|  | 
 | ||
|  | REGISTER_GEOMETRY(GeometrySoftmax, _create); | ||
|  | 
 | ||
|  | } // namespace MNN
 |