//
//  GeometryCosineSimilarity.cpp
//  MNN
//
//  Created by MNN on 2020/07/13.
//  Copyright © 2018, Alibaba Group Holding Limited
//
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "geometry/GeometryComputer.hpp"
 | 
					
						
							|  |  |  | #include "core/OpCommonUtils.hpp"
 | 
					
						
							|  |  |  | #include "geometry/GeometryComputerUtils.hpp"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | class GeometryCosineSimilarity : public GeometryComputer { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                     const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(3 <= inputs.size()); | 
					
						
							|  |  |  |         MNN_ASSERT(1 == outputs.size()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto input0          = inputs[0]; | 
					
						
							|  |  |  |         auto input1          = inputs[1]; | 
					
						
							|  |  |  |         auto dimTensor = inputs[2]; | 
					
						
							|  |  |  |         const auto dim = dimTensor->host<int32_t>()[0]; | 
					
						
							|  |  |  |         MNN_ASSERT(dim == 1); | 
					
						
							|  |  |  |         auto output          = outputs[0]; | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         int dimensions = input0->dimensions(); | 
					
						
							|  |  |  |         int outside = 1; | 
					
						
							|  |  |  |         int channel = 1; | 
					
						
							|  |  |  |         int inside = 1; | 
					
						
							|  |  |  |         for(int i=0; i<dim; i++) { | 
					
						
							|  |  |  |             outside *= input0->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         channel = input0->length(dim); | 
					
						
							|  |  |  |         for(int i=dim+1; i<dimensions; i++) { | 
					
						
							|  |  |  |             inside *= input0->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         auto dimType = input0->getDimensionType(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |          | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |         //input0 transform to NCHW format
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpInput0; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpInput0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto outputDes = TensorUtils::getDescribe(tmpInput0.get()); | 
					
						
							|  |  |  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Tensor::InsideDescribe::Region desReg; | 
					
						
							|  |  |  |             desReg.size[0] = outside; | 
					
						
							|  |  |  |             desReg.size[1] = channel; | 
					
						
							|  |  |  |             desReg.size[2] = inside; | 
					
						
							|  |  |  |             desReg.dst.offset = 0; | 
					
						
							|  |  |  |             desReg.dst.stride[0] = channel*inside; | 
					
						
							|  |  |  |             desReg.dst.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.dst.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.src.offset = 0; | 
					
						
							|  |  |  |             desReg.src.stride[0] = channel*inside; | 
					
						
							|  |  |  |             desReg.src.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.src.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.origin = input0; | 
					
						
							|  |  |  |             outputDes->regions.emplace_back(std::move(desReg)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpInput0); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //input1 transform to NCHW format
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpInput1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpInput1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto outputDes = TensorUtils::getDescribe(tmpInput1.get()); | 
					
						
							|  |  |  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Tensor::InsideDescribe::Region desReg; | 
					
						
							|  |  |  |             desReg.size[0] = outside; | 
					
						
							|  |  |  |             desReg.size[1] = channel; | 
					
						
							|  |  |  |             desReg.size[2] = inside; | 
					
						
							|  |  |  |             desReg.dst.offset = 0; | 
					
						
							|  |  |  |             desReg.dst.stride[0] = channel*inside; | 
					
						
							|  |  |  |             desReg.dst.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.dst.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.src.offset = 0; | 
					
						
							|  |  |  |             desReg.src.stride[0] = channel*inside; | 
					
						
							|  |  |  |             desReg.src.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.src.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.origin = input1; | 
					
						
							|  |  |  |             outputDes->regions.emplace_back(std::move(desReg)); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpInput1); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //input0*input0
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpInput0x0; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpInput0x0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(tmpInput0x0.get()); | 
					
						
							|  |  |  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput0.get(), tmpInput0x0.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpInput0x0); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //input0*input1
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpInput0x1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpInput0x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(tmpInput0x1.get()); | 
					
						
							|  |  |  |             des->dimensionFormat = MNN_DATA_FORMAT_NCHW; | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput1.get(), tmpInput0x1.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpInput0x1); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //input1*input1
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpInput1x1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpInput1x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput1.get(), tmpInput1.get(), tmpInput1x1.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpInput1x1); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         //reduction sum, axis=1, only support NCHW
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> sumValue0x0; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             sumValue0x0.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(sumValue0x0.get()); | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x0.get(), sumValue0x0.get()); | 
					
						
							|  |  |  |             res.extras.emplace_back(sumValue0x0); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         //reduction sum, axis=1, only support NCHW
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> sumValue0x1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             sumValue0x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(sumValue0x1.get()); | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x1.get(), sumValue0x1.get()); | 
					
						
							|  |  |  |             res.extras.emplace_back(sumValue0x1); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //reduction sum, axis=1, only support NCHW
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> sumValue1x1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             sumValue1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(sumValue1x1.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput1x1.get(), sumValue1x1.get()); | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |             res.extras.emplace_back(sumValue1x1); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //sumValue0x0 * sumValue1x1
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> mulValue0x0_1x1; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             mulValue0x0_1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(mulValue0x0_1x1.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, sumValue0x0.get(), sumValue1x1.get(), mulValue0x0_1x1.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(mulValue0x0_1x1); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         //add eps
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> mulValue0x0_1x1_eps; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             mulValue0x0_1x1_eps.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(mulValue0x0_1x1_eps.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             const float eps         = 1e-8f; | 
					
						
							|  |  |  |             auto epsTensor = context.allocConst(op, {1}, halide_type_of<float>()); | 
					
						
							|  |  |  |             epsTensor.get()->host<float>()[0] = eps; | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, mulValue0x0_1x1.get(), epsTensor.get(), mulValue0x0_1x1_eps.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(mulValue0x0_1x1_eps); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         //sqrt(sumValue0x0 * sumValue1x1 + eps)
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> sqrtMulValue; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             sqrtMulValue.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(sqrtMulValue.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_SQRT, mulValue0x0_1x1_eps.get(), sqrtMulValue.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(sqrtMulValue); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         //div
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> tmpOutput; | 
					
						
							|  |  |  |         { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             tmpOutput.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(tmpOutput.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, sumValue0x1.get(), sqrtMulValue.get(), tmpOutput.get()); | 
					
						
							|  |  |  |              | 
					
						
							|  |  |  |             res.extras.emplace_back(tmpOutput); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         //transform to output
 | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             auto outputDes = TensorUtils::getDescribe(output); | 
					
						
							|  |  |  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             Tensor::InsideDescribe::Region desReg; | 
					
						
							|  |  |  |             desReg.size[0] = 1; | 
					
						
							|  |  |  |             desReg.size[1] = outside; | 
					
						
							|  |  |  |             desReg.size[2] = inside; | 
					
						
							|  |  |  |             desReg.dst.offset = 0; | 
					
						
							|  |  |  |             desReg.dst.stride[0] = outside*inside; | 
					
						
							|  |  |  |             desReg.dst.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.dst.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.src.offset = 0; | 
					
						
							|  |  |  |             desReg.src.stride[0] = outside*inside; | 
					
						
							|  |  |  |             desReg.src.stride[1] = inside; | 
					
						
							|  |  |  |             desReg.src.stride[2] = 1; | 
					
						
							|  |  |  |             desReg.origin = tmpOutput.get(); | 
					
						
							|  |  |  |             outputDes->regions.emplace_back(std::move(desReg)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static void _create() { | 
					
						
							|  |  |  |     std::shared_ptr<GeometryComputer> comp(new GeometryCosineSimilarity); | 
					
						
							|  |  |  |     GeometryComputer::registerGeometryComputer(comp, {OpType_CosineSimilarity}); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_GEOMETRY(GeometryCosineSimilarity, _create); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |