mirror of https://github.com/alibaba/MNN.git
243 lines
9.7 KiB
C++
243 lines
9.7 KiB
C++
//
|
|
// GeometryCosineSimilarity.cpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2020/07/13.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include "geometry/GeometryComputer.hpp"
|
|
#include "core/OpCommonUtils.hpp"
|
|
#include "geometry/GeometryComputerUtils.hpp"
|
|
|
|
namespace MNN {
|
|
class GeometryCosineSimilarity : public GeometryComputer {
|
|
public:
|
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
|
|
const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
|
|
MNN_ASSERT(3 <= inputs.size());
|
|
MNN_ASSERT(1 == outputs.size());
|
|
|
|
auto input0 = inputs[0];
|
|
auto input1 = inputs[1];
|
|
auto dimTensor = inputs[2];
|
|
const auto dim = dimTensor->host<int32_t>()[0];
|
|
MNN_ASSERT(dim == 1);
|
|
auto output = outputs[0];
|
|
|
|
int dimensions = input0->dimensions();
|
|
int outside = 1;
|
|
int channel = 1;
|
|
int inside = 1;
|
|
for(int i=0; i<dim; i++) {
|
|
outside *= input0->length(i);
|
|
}
|
|
channel = input0->length(dim);
|
|
for(int i=dim+1; i<dimensions; i++) {
|
|
inside *= input0->length(i);
|
|
}
|
|
auto dimType = input0->getDimensionType();
|
|
|
|
|
|
//input0 transform to NCHW format
|
|
std::shared_ptr<Tensor> tmpInput0;
|
|
{
|
|
tmpInput0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
|
|
auto outputDes = TensorUtils::getDescribe(tmpInput0.get());
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
|
|
Tensor::InsideDescribe::Region desReg;
|
|
desReg.size[0] = outside;
|
|
desReg.size[1] = channel;
|
|
desReg.size[2] = inside;
|
|
desReg.dst.offset = 0;
|
|
desReg.dst.stride[0] = channel*inside;
|
|
desReg.dst.stride[1] = inside;
|
|
desReg.dst.stride[2] = 1;
|
|
desReg.src.offset = 0;
|
|
desReg.src.stride[0] = channel*inside;
|
|
desReg.src.stride[1] = inside;
|
|
desReg.src.stride[2] = 1;
|
|
desReg.origin = input0;
|
|
outputDes->regions.emplace_back(std::move(desReg));
|
|
|
|
res.extras.emplace_back(tmpInput0);
|
|
}
|
|
|
|
//input1 transform to NCHW format
|
|
std::shared_ptr<Tensor> tmpInput1;
|
|
{
|
|
tmpInput1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
|
|
auto outputDes = TensorUtils::getDescribe(tmpInput1.get());
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
outputDes->dimensionFormat = MNN_DATA_FORMAT_NCHW;
|
|
|
|
Tensor::InsideDescribe::Region desReg;
|
|
desReg.size[0] = outside;
|
|
desReg.size[1] = channel;
|
|
desReg.size[2] = inside;
|
|
desReg.dst.offset = 0;
|
|
desReg.dst.stride[0] = channel*inside;
|
|
desReg.dst.stride[1] = inside;
|
|
desReg.dst.stride[2] = 1;
|
|
desReg.src.offset = 0;
|
|
desReg.src.stride[0] = channel*inside;
|
|
desReg.src.stride[1] = inside;
|
|
desReg.src.stride[2] = 1;
|
|
desReg.origin = input1;
|
|
outputDes->regions.emplace_back(std::move(desReg));
|
|
|
|
res.extras.emplace_back(tmpInput1);
|
|
}
|
|
|
|
//input0*input0
|
|
std::shared_ptr<Tensor> tmpInput0x0;
|
|
{
|
|
tmpInput0x0.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(tmpInput0x0.get());
|
|
des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput0.get(), tmpInput0x0.get());
|
|
|
|
res.extras.emplace_back(tmpInput0x0);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//input0*input1
|
|
std::shared_ptr<Tensor> tmpInput0x1;
|
|
{
|
|
tmpInput0x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(tmpInput0x1.get());
|
|
des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput0.get(), tmpInput1.get(), tmpInput0x1.get());
|
|
|
|
res.extras.emplace_back(tmpInput0x1);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//input1*input1
|
|
std::shared_ptr<Tensor> tmpInput1x1;
|
|
{
|
|
tmpInput1x1.reset(Tensor::createDevice<float>({outside, channel, inside}, dimType));
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, tmpInput1.get(), tmpInput1.get(), tmpInput1x1.get());
|
|
|
|
res.extras.emplace_back(tmpInput1x1);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//reduction sum, axis=1, only support NCHW
|
|
std::shared_ptr<Tensor> sumValue0x0;
|
|
{
|
|
sumValue0x0.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(sumValue0x0.get());
|
|
auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x0.get(), sumValue0x0.get());
|
|
res.extras.emplace_back(sumValue0x0);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//reduction sum, axis=1, only support NCHW
|
|
std::shared_ptr<Tensor> sumValue0x1;
|
|
{
|
|
sumValue0x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(sumValue0x1.get());
|
|
auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput0x1.get(), sumValue0x1.get());
|
|
res.extras.emplace_back(sumValue0x1);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//reduction sum, axis=1, only support NCHW
|
|
std::shared_ptr<Tensor> sumValue1x1;
|
|
{
|
|
sumValue1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(sumValue1x1.get());
|
|
|
|
auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, tmpInput1x1.get(), sumValue1x1.get());
|
|
|
|
res.extras.emplace_back(sumValue1x1);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//sumValue0x0 * sumValue1x1
|
|
std::shared_ptr<Tensor> mulValue0x0_1x1;
|
|
{
|
|
mulValue0x0_1x1.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(mulValue0x0_1x1.get());
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, sumValue0x0.get(), sumValue1x1.get(), mulValue0x0_1x1.get());
|
|
|
|
res.extras.emplace_back(mulValue0x0_1x1);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//add eps
|
|
std::shared_ptr<Tensor> mulValue0x0_1x1_eps;
|
|
{
|
|
mulValue0x0_1x1_eps.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(mulValue0x0_1x1_eps.get());
|
|
|
|
const float eps = 1e-8f;
|
|
auto epsTensor = context.allocConst(op, {1}, halide_type_of<float>());
|
|
epsTensor.get()->host<float>()[0] = eps;
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, mulValue0x0_1x1.get(), epsTensor.get(), mulValue0x0_1x1_eps.get());
|
|
|
|
res.extras.emplace_back(mulValue0x0_1x1_eps);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
//sqrt(sumValue0x0 * sumValue1x1 + eps)
|
|
std::shared_ptr<Tensor> sqrtMulValue;
|
|
{
|
|
sqrtMulValue.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(sqrtMulValue.get());
|
|
|
|
auto cmd = GeometryComputerUtils::makeUnary(UnaryOpOperation_SQRT, mulValue0x0_1x1_eps.get(), sqrtMulValue.get());
|
|
|
|
res.extras.emplace_back(sqrtMulValue);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
//div
|
|
std::shared_ptr<Tensor> tmpOutput;
|
|
{
|
|
tmpOutput.reset(Tensor::createDevice<float>({outside, 1, inside}, dimType));
|
|
auto des = TensorUtils::getDescribe(tmpOutput.get());
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_REALDIV, sumValue0x1.get(), sqrtMulValue.get(), tmpOutput.get());
|
|
|
|
res.extras.emplace_back(tmpOutput);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
//transform to output
|
|
{
|
|
auto outputDes = TensorUtils::getDescribe(output);
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
Tensor::InsideDescribe::Region desReg;
|
|
desReg.size[0] = 1;
|
|
desReg.size[1] = outside;
|
|
desReg.size[2] = inside;
|
|
desReg.dst.offset = 0;
|
|
desReg.dst.stride[0] = outside*inside;
|
|
desReg.dst.stride[1] = inside;
|
|
desReg.dst.stride[2] = 1;
|
|
desReg.src.offset = 0;
|
|
desReg.src.stride[0] = outside*inside;
|
|
desReg.src.stride[1] = inside;
|
|
desReg.src.stride[2] = 1;
|
|
desReg.origin = tmpOutput.get();
|
|
outputDes->regions.emplace_back(std::move(desReg));
|
|
}
|
|
return true;
|
|
|
|
}
|
|
};
|
|
|
|
static void _create() {
|
|
std::shared_ptr<GeometryComputer> comp(new GeometryCosineSimilarity);
|
|
GeometryComputer::registerGeometryComputer(comp, {OpType_CosineSimilarity});
|
|
}
|
|
|
|
REGISTER_GEOMETRY(GeometryCosineSimilarity, _create);
|
|
|
|
} // namespace MNN
|