mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			249 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			249 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  GeometryCosineSimilarity.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2020/07/13.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include "geometry/GeometryComputer.hpp"
 | ||
|  | #include "core/OpCommonUtils.hpp"
 | ||
|  | #include "geometry/GeometryComputerUtils.hpp"
 | ||
|  | 
 | ||
|  | namespace MNN { | ||
|  | class GeometryCosineSimilarity : public GeometryComputer { | ||
|  | public: | ||
|  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, | ||
|  |                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override { | ||
|  |         MNN_ASSERT(3 <= inputs.size()); | ||
|  |         MNN_ASSERT(1 == outputs.size()); | ||
|  | 
 | ||
|  |         auto input0          = inputs[0]; | ||
|  |         auto input1          = inputs[1]; | ||
|  |         auto dimTensor = inputs[2]; | ||
|  |         const auto dim = dimTensor->host<int32_t>()[0]; | ||
|  |         MNN_ASSERT(dim == 1); | ||
|  |         auto output          = outputs[0]; | ||
|  |          | ||
|  |         int dimensions = input0->dimensions(); | ||
|  |         int outside = 1; | ||
|  |         int channel = 1; | ||
|  |         int inside = 1; | ||
|  |         for(int i=0; i<dim; i++) { | ||
|  |             outside *= input0->length(i); | ||
|  |         } | ||
|  |         channel = input0->length(dim); | ||
|  |         for(int i=dim+1; i<dimensions; i++) { | ||
|  |             inside *= input0->length(i); | ||
|  |         } | ||
|  |          | ||
|  |              | ||
|  |         //input0 transform to NCHW format
 | ||
|  |         std::shared_ptr<Tensor> tmpInput0; | ||
|  |         { | ||
|  |             tmpInput0.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto outputDes = TensorUtils::getDescribe(tmpInput0.get()); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  | 
 | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = channel*inside; | ||
|  |             desReg.src.stride[1] = inside; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = input0; | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  | 
 | ||
|  |             res.extras.emplace_back(tmpInput0); | ||
|  |         } | ||
|  |          | ||
|  |         //input1 transform to NCHW format
 | ||
|  |         std::shared_ptr<Tensor> tmpInput1; | ||
|  |         { | ||
|  |             tmpInput1.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto outputDes = TensorUtils::getDescribe(tmpInput1.get()); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |             outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  | 
 | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = outside; | ||
|  |             desReg.size[1] = channel; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = channel*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = channel*inside; | ||
|  |             desReg.src.stride[1] = inside; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = input1; | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpInput1); | ||
|  |         } | ||
|  |          | ||
|  |         //input0*input0
 | ||
|  |         std::shared_ptr<Tensor> tmpInput0x0; | ||
|  |         { | ||
|  |             tmpInput0x0.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(tmpInput0x0.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput0.get(), tmpInput0x0.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpInput0x0); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |          | ||
|  |         //input0*input1
 | ||
|  |         std::shared_ptr<Tensor> tmpInput0x1; | ||
|  |         { | ||
|  |             tmpInput0x1.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(tmpInput0x1.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput1.get(), tmpInput0x1.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpInput0x1); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |          | ||
|  |         //input1*input1
 | ||
|  |         std::shared_ptr<Tensor> tmpInput1x1; | ||
|  |         { | ||
|  |             tmpInput1x1.reset(Tensor::createDevice<float>({outside, channel, inside})); | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput1.get(), tmpInput1.get(), tmpInput1x1.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpInput1x1); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  | 
 | ||
|  |         //reduction sum, axis=1, only support NCHW
 | ||
|  |         std::shared_ptr<Tensor> sumValue0x0; | ||
|  |         { | ||
|  |             sumValue0x0.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(sumValue0x0.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x0.get(), sumValue0x0.get()); | ||
|  |             res.extras.emplace_back(sumValue0x0); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  | 
 | ||
|  |         //reduction sum, axis=1, only support NCHW
 | ||
|  |         std::shared_ptr<Tensor> sumValue0x1; | ||
|  |         { | ||
|  |             sumValue0x1.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(sumValue0x1.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x1.get(), sumValue0x1.get()); | ||
|  |             res.extras.emplace_back(sumValue0x1); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |          | ||
|  |         //reduction sum, axis=1, only support NCHW
 | ||
|  |         std::shared_ptr<Tensor> sumValue1x1; | ||
|  |         { | ||
|  |             sumValue1x1.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(sumValue1x1.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput1x1.get(), sumValue1x1.get()); | ||
|  |          | ||
|  |             res.extras.emplace_back(sumValue1x1); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |          | ||
|  |         //sumValue0x0 * sumValue1x1
 | ||
|  |         std::shared_ptr<Tensor> mulValue0x0_1x1; | ||
|  |         { | ||
|  |             mulValue0x0_1x1.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(mulValue0x0_1x1.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, sumValue0x0.get(), sumValue1x1.get(), mulValue0x0_1x1.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(mulValue0x0_1x1); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |          | ||
|  |         //add eps
 | ||
|  |         std::shared_ptr<Tensor> mulValue0x0_1x1_eps; | ||
|  |         { | ||
|  |             mulValue0x0_1x1_eps.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(mulValue0x0_1x1_eps.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             const float eps         = 1e-8f; | ||
|  |             auto epsTensor = context.allocConst(op, {1}, halide_type_of<float>()); | ||
|  |             epsTensor.get()->host<float>()[0] = eps; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, mulValue0x0_1x1.get(), epsTensor.get(), mulValue0x0_1x1_eps.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(mulValue0x0_1x1_eps); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  | 
 | ||
|  |         //sqrt(sumValue0x0 * sumValue1x1 + eps)
 | ||
|  |         std::shared_ptr<Tensor> sqrtMulValue; | ||
|  |         { | ||
|  |             sqrtMulValue.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(sqrtMulValue.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_SQRT, mulValue0x0_1x1_eps.get(), sqrtMulValue.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(sqrtMulValue); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |         //div
 | ||
|  |         std::shared_ptr<Tensor> tmpOutput; | ||
|  |         { | ||
|  |             tmpOutput.reset(Tensor::createDevice<float>({outside, 1, inside})); | ||
|  |             auto des = TensorUtils::getDescribe(tmpOutput.get()); | ||
|  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | ||
|  |              | ||
|  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, sumValue0x1.get(), sqrtMulValue.get(), tmpOutput.get()); | ||
|  |              | ||
|  |             res.extras.emplace_back(tmpOutput); | ||
|  |             res.command.emplace_back(std::move(cmd)); | ||
|  |         } | ||
|  |         //transform to output
 | ||
|  |         { | ||
|  |             auto outputDes = TensorUtils::getDescribe(output); | ||
|  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | ||
|  |             Tensor::InsideDescribe::Region desReg; | ||
|  |             desReg.size[0] = 1; | ||
|  |             desReg.size[1] = outside; | ||
|  |             desReg.size[2] = inside; | ||
|  |             desReg.dst.offset = 0; | ||
|  |             desReg.dst.stride[0] = outside*inside; | ||
|  |             desReg.dst.stride[1] = inside; | ||
|  |             desReg.dst.stride[2] = 1; | ||
|  |             desReg.src.offset = 0; | ||
|  |             desReg.src.stride[0] = outside*inside; | ||
|  |             desReg.src.stride[1] = inside; | ||
|  |             desReg.src.stride[2] = 1; | ||
|  |             desReg.origin = tmpOutput.get(); | ||
|  |             outputDes->regions.emplace_back(std::move(desReg)); | ||
|  |         } | ||
|  |         return true; | ||
|  |          | ||
|  |     } | ||
|  | }; | ||
|  | 
 | ||
|  | static void _create() { | ||
|  |     std::shared_ptr<GeometryComputer> comp(new GeometryCosineSimilarity); | ||
|  |     GeometryComputer::registerGeometryComputer(comp, {OpType_CosineSimilarity}); | ||
|  | } | ||
|  | 
 | ||
|  | REGISTER_GEOMETRY(GeometryCosineSimilarity, _create); | ||
|  | 
 | ||
|  | } // namespace MNN
 |