// Mirrored from https://github.com/alibaba/MNN.git
//
//  GeometryInnerProduct.cpp
//  MNN
//
//  Created by MNN on 2020/05/07.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include "geometry/GeometryComputer.hpp"
#include "geometry/GeometryComputerUtils.hpp"
#include "core/OpCommonUtils.hpp"
#include "core/ConvolutionCommon.hpp"
#include "ConvertUtils.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
namespace MNN {
class GeometryInnerProduct : public GeometryComputer {
|
|
public:
|
|
|
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
|
|
const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
|
|
auto parameter = op->main_as_InnerProduct();
|
|
int outputCount = parameter->outputCount();
|
|
int srcCount = parameter->weight()->size() / outputCount;
|
|
|
|
MNN_ASSERT(inputs.size() == 1);
|
|
MNN_ASSERT(outputs.size() == 1);
|
|
auto input = inputs[0];
|
|
auto output = outputs[0];
|
|
int inputDims = input->dimensions();
|
|
int outputDims = output->dimensions();
|
|
MNN_ASSERT(inputDims >= 2);
|
|
MNN_ASSERT(outputDims == 2);
|
|
MNN_ASSERT(output->length(1) == outputCount);
|
|
|
|
int batch = output->length(0);
|
|
MNN_ASSERT(input->length(0) == batch);
|
|
int mulNum = 1;
|
|
for(int i=1; i < inputDims; i++) {
|
|
mulNum *= input->length(i);
|
|
}
|
|
if (srcCount != mulNum) {
|
|
return false;
|
|
}
|
|
|
|
Tensor* A = nullptr;
|
|
Tensor* B = nullptr;
|
|
{
|
|
std::shared_ptr<Tensor> tmpInput(new Tensor);
|
|
tmpInput->buffer().type = halide_type_of<float>();
|
|
tmpInput->buffer().dimensions = 2;
|
|
tmpInput->setLength(0, batch);
|
|
tmpInput->setLength(1, srcCount);
|
|
auto des = TensorUtils::getDescribe(tmpInput.get());
|
|
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
|
|
des->regions.clear();
|
|
des->regions.reserve(1);
|
|
|
|
Tensor::InsideDescribe::Region region;
|
|
region.origin = input;
|
|
region.size[0] = 1;
|
|
region.size[1] = batch;
|
|
region.size[2] = srcCount;
|
|
region.src.offset = 0;
|
|
region.dst.offset = 0;
|
|
region.src.stride[0] = 1;
|
|
region.dst.stride[0] = 1;
|
|
region.src.stride[1] = srcCount;
|
|
region.dst.stride[1] = srcCount;
|
|
region.src.stride[2] = 1;
|
|
region.dst.stride[2] = 1;
|
|
des->regions.emplace_back(std::move(region));
|
|
|
|
A = tmpInput.get();
|
|
res.extras.emplace_back(tmpInput);
|
|
}
|
|
|
|
std::shared_ptr<Tensor> tmpOutput(new Tensor);
|
|
std::shared_ptr<Tensor> C(new Tensor);
|
|
auto constTensors = context.searchConst(op);
|
|
Tensor* weight = nullptr;
|
|
Tensor* bias = nullptr;
|
|
if (!constTensors.empty()) {
|
|
MNN_ASSERT(constTensors.size() == 2);
|
|
weight = constTensors[0].get();
|
|
bias = constTensors[1].get();
|
|
} else {
|
|
auto weightTensor = context.allocConst(op, {outputCount, srcCount}, halide_type_of<float>());
|
|
::memcpy(weightTensor.get()->host<float>(), parameter->weight()->data(), parameter->weight()->size()*sizeof(float));
|
|
weight = weightTensor.get();
|
|
auto biasTensor = context.allocConst(op, {batch, outputCount}, halide_type_of<float>());
|
|
::memcpy(biasTensor.get()->host<float>(), parameter->bias()->data(), parameter->bias()->size()*sizeof(float));
|
|
bias = biasTensor.get();
|
|
}
|
|
{
|
|
B = weight;
|
|
|
|
C->buffer().type = halide_type_of<float>();
|
|
C->buffer().dimensions = 2;
|
|
C->setLength(0, batch);
|
|
C->setLength(1, outputCount);
|
|
|
|
auto cmd = GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, false, true);
|
|
res.extras.emplace_back(C);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
{
|
|
tmpOutput->buffer().type = halide_type_of<float>();
|
|
tmpOutput->buffer().dimensions = 2;
|
|
tmpOutput->setLength(0, batch);
|
|
tmpOutput->setLength(1, outputCount);
|
|
|
|
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C.get(), bias, tmpOutput.get());
|
|
res.extras.emplace_back(tmpOutput);
|
|
res.command.emplace_back(std::move(cmd));
|
|
}
|
|
|
|
{
|
|
auto des = TensorUtils::getDescribe(output);
|
|
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
des->regions.clear();
|
|
des->regions.reserve(1);
|
|
|
|
Tensor::InsideDescribe::Region region;
|
|
region.origin = tmpOutput.get();
|
|
region.size[0] = 1;
|
|
region.size[1] = batch;
|
|
region.size[2] = outputCount;
|
|
region.src.offset = 0;
|
|
region.dst.offset = 0;
|
|
region.src.stride[0] = 1;
|
|
region.dst.stride[0] = 1;
|
|
region.src.stride[1] = outputCount;
|
|
region.dst.stride[1] = outputCount;
|
|
region.src.stride[2] = 1;
|
|
region.dst.stride[2] = 1;
|
|
des->regions.emplace_back(std::move(region));
|
|
}
|
|
|
|
return true;
|
|
}
|
|
};
static void _create() {
|
|
std::shared_ptr<GeometryComputer> comp(new GeometryInnerProduct);
|
|
GeometryComputer::registerGeometryComputer(comp, {OpType_InnerProduct});
|
|
}
|
|
|
|
REGISTER_GEOMETRY(GeometryInnerProduct, _create);
} // namespace MNN