// mirror of https://github.com/alibaba/MNN.git
//
//  GeometryGather.cpp
//  MNN
//
//  Created by MNN on 2020/06/09.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include "geometry/GeometryComputer.hpp"
#include "core/OpCommonUtils.hpp"

namespace MNN {
class GeometryGather : public DefaultGeometryComputer {
|
|
public:
|
|
virtual std::vector<bool> onGetOutputVirtual(const Op* op, const std::vector<Tensor*>& inputs,
|
|
const std::vector<Tensor*>& outputs) const override {
|
|
MNN_ASSERT(inputs.size() == 2);
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto embedding = inputs[0];
|
|
auto indices = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
const int firstDimStride = embedding->buffer().dim[0].stride;
|
|
if (TensorUtils::getDescribe(indices)->usage == MNN::Tensor::InsideDescribe::CONSTANT && firstDimStride != 0) {
|
|
std::vector<bool> res(outputs.size(), true);
|
|
return res;
|
|
}
|
|
return std::vector<bool>(outputs.size(), false);
|
|
}
|
|
|
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
|
Context& context, CommandBuffer& res) const override {
|
|
MNN_ASSERT(2 == inputs.size());
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto embedding = inputs[0];
|
|
auto indices = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
const int firstDimStride = embedding->buffer().dim[0].stride;
|
|
if (TensorUtils::getDescribe(indices)->usage != MNN::Tensor::InsideDescribe::CONSTANT || firstDimStride == 0) {
|
|
Command cmd;
|
|
cmd.op = op;
|
|
cmd.inputs = std::move(inputs);
|
|
cmd.outputs = std::move(outputs);
|
|
res.command.emplace_back(std::move(cmd));
|
|
return true;
|
|
}
|
|
|
|
auto bytes = embedding->buffer().type.bytes();
|
|
|
|
const size_t indicesCount = indices->elementSize();
|
|
const auto limit = embedding->length(0);
|
|
const int* indicesData = indices->host<int32_t>();
|
|
|
|
auto outputDes = TensorUtils::getDescribe(output);
|
|
outputDes->regions.clear();
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
for (int i = 0; i < indicesCount; i++) {
|
|
if (indicesData[i] < 0 || indicesData[i] > limit) {
|
|
MNN_PRINT("Gather indice error\n");
|
|
return false;
|
|
}
|
|
|
|
Tensor::InsideDescribe::Region slice;
|
|
slice.origin = embedding;
|
|
slice.size[0] = 1;
|
|
slice.size[1] = 1;
|
|
slice.size[2] = firstDimStride;
|
|
slice.src.offset = firstDimStride * indicesData[i];
|
|
slice.dst.offset = i * firstDimStride;
|
|
slice.src.stride[0] = 1;
|
|
slice.src.stride[1] = 1;
|
|
slice.src.stride[2] = 1;
|
|
slice.dst.stride[0] = 1;
|
|
slice.dst.stride[1] = 1;
|
|
slice.dst.stride[2] = 1;
|
|
outputDes->regions.emplace_back(std::move(slice));
|
|
}
|
|
return true;
|
|
}
|
|
};
class GeometryGatherND : public DefaultGeometryComputer {
|
|
public:
|
|
virtual std::vector<bool> onGetOutputVirtual(const Op* op, const std::vector<Tensor*>& inputs,
|
|
const std::vector<Tensor*>& outputs) const override {
|
|
MNN_ASSERT(inputs.size() == 2);
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto params = inputs[0];
|
|
auto indices = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
int mSliceN = 1;
|
|
int mSliceSize = 1;
|
|
for (int i = 0; i < indices->dimensions() - 1; ++i) {
|
|
mSliceN *= indices->length(i);
|
|
}
|
|
auto indiceNd = indices->length(indices->dimensions() - 1);
|
|
std::vector<int> mDimsToCount;
|
|
mDimsToCount.resize(indiceNd);
|
|
for (int i = indiceNd; i < params->dimensions(); ++i) {
|
|
mSliceSize *= params->length(i);
|
|
}
|
|
|
|
if (TensorUtils::getDescribe(indices)->usage == MNN::Tensor::InsideDescribe::CONSTANT && mSliceSize != 0) {
|
|
std::vector<bool> res(outputs.size(), true);
|
|
return res;
|
|
} else {
|
|
std::vector<bool> res(outputs.size(), false);
|
|
return res;
|
|
}
|
|
}
|
|
|
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
|
Context& context, CommandBuffer& res) const override {
|
|
MNN_ASSERT(2 == inputs.size());
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto params = inputs[0];
|
|
auto indice = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
int mSliceN = 1;
|
|
int mSliceSize = 1;
|
|
for (int i = 0; i < indice->dimensions() - 1; ++i) {
|
|
mSliceN *= indice->length(i);
|
|
}
|
|
auto indiceNd = indice->length(indice->dimensions() - 1);
|
|
std::vector<int> mDimsToCount;
|
|
mDimsToCount.resize(indiceNd);
|
|
for (int i = indiceNd; i < params->dimensions(); ++i) {
|
|
mSliceSize *= params->length(i);
|
|
}
|
|
|
|
if (TensorUtils::getDescribe(indice)->usage != MNN::Tensor::InsideDescribe::CONSTANT || mSliceSize == 0) {
|
|
Command cmd;
|
|
cmd.op = op;
|
|
cmd.inputs = std::move(inputs);
|
|
cmd.outputs = std::move(outputs);
|
|
res.command.emplace_back(std::move(cmd));
|
|
return true;
|
|
}
|
|
|
|
auto paramSize = params->elementSize();
|
|
for (int i = 0; i < indiceNd; ++i) {
|
|
mDimsToCount[i] = paramSize / params->length(i);
|
|
paramSize = mDimsToCount[i];
|
|
}
|
|
mDimsToCount.resize(indiceNd);
|
|
auto indiceData = indice->host<int32_t>();
|
|
|
|
auto outputDes = TensorUtils::getDescribe(output);
|
|
outputDes->regions.clear();
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
for (int i = 0; i < mSliceN; i++) {
|
|
int fromPos = 0;
|
|
for (int j = 0; j < indiceNd; ++j) {
|
|
fromPos += mDimsToCount[j] * indiceData[i * indiceNd + j];
|
|
}
|
|
|
|
Tensor::InsideDescribe::Region slice;
|
|
slice.origin = params;
|
|
slice.size[0] = 1;
|
|
slice.size[1] = 1;
|
|
slice.size[2] = mSliceSize;
|
|
slice.src.offset = fromPos;
|
|
slice.dst.offset = i * mSliceSize;
|
|
slice.src.stride[0] = 1;
|
|
slice.src.stride[1] = 1;
|
|
slice.src.stride[2] = 1;
|
|
slice.dst.stride[0] = 1;
|
|
slice.dst.stride[1] = 1;
|
|
slice.dst.stride[2] = 1;
|
|
outputDes->regions.emplace_back(std::move(slice));
|
|
}
|
|
return true;
|
|
}
|
|
};
class GeometryGatherV2 : public DefaultGeometryComputer {
|
|
public:
|
|
virtual std::vector<bool> onGetOutputVirtual(const Op* op, const std::vector<Tensor*>& inputs,
|
|
const std::vector<Tensor*>& outputs) const override {
|
|
MNN_ASSERT(inputs.size() >= 2);
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto params = inputs[0];
|
|
auto indices = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
int axis = 0;
|
|
if (inputs.size() == 3) {
|
|
const Tensor* axisTensor = inputs[2];
|
|
axis = axisTensor->host<int32_t>()[0];
|
|
}
|
|
|
|
MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions);
|
|
|
|
if (axis < 0) {
|
|
axis = params->buffer().dimensions + axis;
|
|
}
|
|
const int gatherDimSize = params->buffer().dim[axis].extent;
|
|
const int N = indices->elementSize();
|
|
MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max());
|
|
|
|
int inside = 1;
|
|
for (int i = axis + 1; i < params->dimensions(); ++i) {
|
|
inside *= params->length(i);
|
|
}
|
|
|
|
if (TensorUtils::getDescribe(indices)->usage == MNN::Tensor::InsideDescribe::CONSTANT && inside != 0) {
|
|
std::vector<bool> res(outputs.size(), true);
|
|
return res;
|
|
}
|
|
return std::vector<bool>(outputs.size(), false);
|
|
}
|
|
|
|
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
|
Context& context, CommandBuffer& res) const override {
|
|
MNN_ASSERT(inputs.size() >= 2);
|
|
MNN_ASSERT(1 == outputs.size());
|
|
auto params = inputs[0];
|
|
auto indices = inputs[1];
|
|
auto output = outputs[0];
|
|
|
|
int axis = 0;
|
|
if (inputs.size() == 3) {
|
|
const Tensor* axisTensor = inputs[2];
|
|
axis = axisTensor->host<int32_t>()[0];
|
|
}
|
|
MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions);
|
|
|
|
if (axis < 0) {
|
|
axis = params->buffer().dimensions + axis;
|
|
}
|
|
const int gatherDimSize = params->buffer().dim[axis].extent;
|
|
const int N = indices->elementSize();
|
|
MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max());
|
|
|
|
int inside = 1;
|
|
int outside = 1;
|
|
for (int i = 0; i < axis; ++i) {
|
|
outside *= params->length(i);
|
|
}
|
|
for (int i = axis + 1; i < params->dimensions(); ++i) {
|
|
inside *= params->length(i);
|
|
}
|
|
|
|
if (TensorUtils::getDescribe(indices)->usage != MNN::Tensor::InsideDescribe::CONSTANT || inside == 0) {
|
|
Command cmd;
|
|
cmd.op = op;
|
|
cmd.inputs = std::move(inputs);
|
|
cmd.outputs = std::move(outputs);
|
|
res.command.emplace_back(std::move(cmd));
|
|
return true;
|
|
}
|
|
|
|
const int limit = params->length(axis);
|
|
auto bytes = output->buffer().type.bytes();
|
|
const int insideStride = inside;
|
|
const int outputOutsideStride = inside * N;
|
|
const int inputOutsideStride = inside * inputs[0]->length(axis);
|
|
const int* indicesPtr = indices->host<int32_t>();
|
|
|
|
auto outputDes = TensorUtils::getDescribe(output);
|
|
outputDes->regions.clear();
|
|
outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
|
|
for (int o = 0; o < outside; ++o) {
|
|
for (int i = 0; i < N; i++) {
|
|
if (indicesPtr[i] < 0 || indicesPtr[i] > limit) {
|
|
continue;
|
|
}
|
|
Tensor::InsideDescribe::Region slice;
|
|
slice.origin = params;
|
|
slice.size[0] = 1;
|
|
slice.size[1] = 1;
|
|
slice.size[2] = insideStride;
|
|
slice.src.offset = inputOutsideStride * o + insideStride * indicesPtr[i];
|
|
slice.dst.offset = outputOutsideStride * o + i * insideStride;
|
|
slice.src.stride[0] = 1;
|
|
slice.src.stride[1] = 1;
|
|
slice.src.stride[2] = 1;
|
|
slice.dst.stride[0] = 1;
|
|
slice.dst.stride[1] = 1;
|
|
slice.dst.stride[2] = 1;
|
|
outputDes->regions.emplace_back(std::move(slice));
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
};
// Registration hook for this translation unit, invoked via REGISTER_GEOMETRY.
// NOTE(review): every registration below is commented out, so the
// GeometryGather / GeometryGatherND / GeometryGatherV2 classes defined above
// are currently disabled — presumably superseded by another implementation;
// confirm before re-enabling. The empty hook is kept so the REGISTER_GEOMETRY
// entry still links.
static void _create() {
    // std::shared_ptr<GeometryComputer> comp(new GeometryGather);
    // GeometryComputer::registerGeometryComputer(comp, {OpType_Gather});
    //
    // std::shared_ptr<GeometryComputer> comp2(new GeometryGatherND);
    // GeometryComputer::registerGeometryComputer(comp2, {OpType_GatherND});
    //
    // std::shared_ptr<GeometryComputer> comp3(new GeometryGatherV2);
    // GeometryComputer::registerGeometryComputer(comp3, {OpType_GatherV2});
}

REGISTER_GEOMETRY(GeometryGather, _create);
} // namespace MNN