mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			215 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			215 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  GeometrySoftmax.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2020/06/28.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include "geometry/GeometryComputer.hpp"
 | |
| #include "core/OpCommonUtils.hpp"
 | |
| #include "geometry/GeometryComputerUtils.hpp"
 | |
| 
 | |
| namespace MNN {
 | |
| class GeometrySoftmax : public GeometryComputer {
 | |
| public:
 | |
|     virtual std::vector<bool> onGetOutputVirtual(const Op* op, const std::vector<Tensor*>& inputs,
 | |
|                                                  const std::vector<Tensor*>& outputs) const override {
 | |
|         auto  axis = op->main_as_Axis()->axis();
 | |
|         if (axis < 0) {
 | |
|             axis = inputs[0]->dimensions() + axis;
 | |
|         }
 | |
|         
 | |
|         if (axis == 1) {
 | |
|             return std::vector<bool>(outputs.size(), false);
 | |
|         }
 | |
|         return std::vector<bool>(outputs.size(), true);
 | |
|     }
 | |
|     
 | |
|     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
 | |
|                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
 | |
|         MNN_ASSERT(1 == inputs.size());
 | |
|         MNN_ASSERT(1 == outputs.size());
 | |
| 
 | |
|         auto input     = inputs[0];
 | |
|         auto output    = outputs[0];
 | |
|         auto dims      = input->buffer().dimensions;
 | |
|         
 | |
|         auto  axis = op->main_as_Axis()->axis();
 | |
|         if (axis < 0) {
 | |
|             axis = inputs[0]->dimensions() + axis;
 | |
|         }
 | |
|         
 | |
|         if (axis == 1) {
 | |
|             Command cmd;
 | |
|             cmd.op      = op;
 | |
|             cmd.inputs  = std::move(inputs);
 | |
|             cmd.outputs = std::move(outputs);
 | |
|             res.command.emplace_back(std::move(cmd));
 | |
|             return true;
 | |
|         }
 | |
|         
 | |
|         int inside  = 1;
 | |
|         int outside = 1;
 | |
|         int channel = 1;
 | |
|         for (int i = 0; i < axis; ++i) {
 | |
|             outside *= input->length(i);
 | |
|         }
 | |
|         channel = input->length(axis);
 | |
|         for (int i = axis + 1; i < dims; ++i) {
 | |
|             inside *= input->length(i);
 | |
|         }
 | |
| 
 | |
|         //input transform to NCHW format
 | |
|         std::shared_ptr<Tensor> tmpInput;
 | |
|         {
 | |
|             tmpInput.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto outputDes = TensorUtils::getDescribe(tmpInput.get());
 | |
|             outputDes->regions.clear();
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
| 
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = outside;
 | |
|             desReg.size[1] = channel;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = channel*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = channel*inside;
 | |
|             desReg.src.stride[1] = inside;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = input;
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|             
 | |
|             res.extras.emplace_back(tmpInput);
 | |
|         }
 | |
|         
 | |
|         //reduction max, axis=1
 | |
|         std::shared_ptr<Tensor> maxValue;
 | |
|         {
 | |
|             maxValue.reset(Tensor::createDevice<float>({outside, 1, inside}));
 | |
|             res.extras.emplace_back(maxValue);
 | |
|             res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_MAXIMUM, tmpInput.get(), maxValue.get()));
 | |
|         }
 | |
|         
 | |
|         //broadcast reduction axis dim
 | |
|         std::shared_ptr<Tensor> maxBroadValue;
 | |
|         {
 | |
|             maxBroadValue.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto outputDes = TensorUtils::getDescribe(maxBroadValue.get());
 | |
|             outputDes->regions.clear();
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = outside;
 | |
|             desReg.size[1] = channel;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = channel*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = inside;
 | |
|             desReg.src.stride[1] = 0;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = maxValue.get();
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|             
 | |
|             res.extras.emplace_back(maxBroadValue);
 | |
|         }
 | |
| 
 | |
|         //sub
 | |
|         std::shared_ptr<Tensor> subMaxValue;
 | |
|         {
 | |
|             subMaxValue.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_SUB, tmpInput.get(), maxBroadValue.get(), subMaxValue.get());
 | |
|             res.extras.emplace_back(subMaxValue);
 | |
|             res.command.emplace_back(std::move(cmd));
 | |
|         }
 | |
|         //exp
 | |
|         std::shared_ptr<Tensor> expValue;
 | |
|         {
 | |
|             expValue.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_EXP, subMaxValue.get(), expValue.get());
 | |
|             res.extras.emplace_back(expValue);
 | |
|             res.command.emplace_back(std::move(cmd));
 | |
|             
 | |
|         }
 | |
|         
 | |
|         //reduction sum, axis=2, only support NCHW
 | |
|         std::shared_ptr<Tensor> sumValue;
 | |
|         {
 | |
|             sumValue.reset(Tensor::createDevice<float>({outside, 1, inside}));
 | |
|             res.extras.emplace_back(sumValue);
 | |
|             res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_SUM, expValue.get(), sumValue.get()));
 | |
|         }
 | |
|         
 | |
|         //broadcast reduction axis dim
 | |
|         std::shared_ptr<Tensor> sumBroadValue;
 | |
|         {
 | |
|             sumBroadValue.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto outputDes = TensorUtils::getDescribe(sumBroadValue.get());
 | |
|             outputDes->regions.clear();
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = outside;
 | |
|             desReg.size[1] = channel;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = channel*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = inside;
 | |
|             desReg.src.stride[1] = 0;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = sumValue.get();
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
| 
 | |
|             res.extras.emplace_back(sumBroadValue);
 | |
|         }
 | |
| 
 | |
|         //div
 | |
|         std::shared_ptr<Tensor> tmpOutput;
 | |
|         {
 | |
|             tmpOutput.reset(Tensor::createDevice<float>({outside, channel, inside}));
 | |
|             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, expValue.get(), sumBroadValue.get(), tmpOutput.get());
 | |
|             res.extras.emplace_back(tmpOutput);
 | |
|             res.command.emplace_back(std::move(cmd));
 | |
|         }
 | |
| 
 | |
|         //transform to output
 | |
|         {
 | |
|             auto outputDes = TensorUtils::getDescribe(output);
 | |
|             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
 | |
|             Tensor::InsideDescribe::Region desReg;
 | |
|             desReg.size[0] = outside;
 | |
|             desReg.size[1] = channel;
 | |
|             desReg.size[2] = inside;
 | |
|             desReg.dst.offset = 0;
 | |
|             desReg.dst.stride[0] = channel*inside;
 | |
|             desReg.dst.stride[1] = inside;
 | |
|             desReg.dst.stride[2] = 1;
 | |
|             desReg.src.offset = 0;
 | |
|             desReg.src.stride[0] = channel*inside;
 | |
|             desReg.src.stride[1] = inside;
 | |
|             desReg.src.stride[2] = 1;
 | |
|             desReg.origin = tmpOutput.get();
 | |
|             outputDes->regions.emplace_back(std::move(desReg));
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
| };
 | |
| 
 | |
| static void _create() {
 | |
| //    std::shared_ptr<GeometryComputer> comp(new GeometrySoftmax);
 | |
| //    GeometryComputer::registerGeometryComputer(comp, {OpType_Softmax});
 | |
| }
 | |
| 
 | |
| REGISTER_GEOMETRY(GeometrySoftmax, _create);
 | |
| 
 | |
| } // namespace MNN
 |