//
// GeometryGather.cpp
// MNN
//
// Created by MNN on 2020/06/09.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "geometry/GeometryComputer.hpp"
#include "geometry/GeometryComputerUtils.hpp"
#include "core/OpCommonUtils.hpp"
#include <array>
#include <limits>
namespace MNN {
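
// Gather / GatherV2 are lowered in two ways: when the indices are a small
// compile-time constant, the output is described directly as Raster regions
// over the input; otherwise a Loop op (OpType_While with OpParameter_LoopParam)
// is emitted that copies one gathered slice per iteration.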
static void _computeGather(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           GeometryComputer::Context& context, CommandBuffer& res, const Op* op) {
    int axis = 0;
    if (inputs.size() == 3) {
        const Tensor *axisTensor = inputs[2];
        axis = axisTensor->host<int32_t>()[0];
    }
    if (op->main_type() == OpParameter_Axis) {
        axis = op->main_as_Axis()->axis();
    }
    auto params = inputs[0];
    auto indices = inputs[1];
    auto output = outputs[0];
    MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions);
    if (axis < 0) {
        axis = params->buffer().dimensions + axis;
    }
    const int gatherDimSize = params->buffer().dim[axis].extent;
    const int N = indices->elementSize();
    MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max());

    int inside = 1;
    int outside = 1;
    for (int i = 0; i < axis; ++i) {
        outside *= params->length(i);
    }
    for (int i = axis + 1; i < params->dimensions(); ++i) {
        inside *= params->length(i);
    }
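    // Fast path: for constant indices with fewer than `limit` elements, express
    // the gather directly as Raster regions, one region per index. Indices
    // outside [0, axisLen) are skipped and produce no region.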
    const int limit = 3;
    if (TensorUtils::getDescribe(indices)->usage == Tensor::InsideDescribe::CONSTANT && N < limit) {
        // Use Raster instead of loop
        auto outDes = TensorUtils::getDescribe(output);
        outDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        outDes->regions.clear();
        outDes->regions.reserve(N);
        auto indicePtr = indices->host<int>();
        auto axisLen = params->length(axis);
        for (int i = 0; i < N; ++i) {
            if (indicePtr[i] < 0 || indicePtr[i] >= axisLen) {
                continue;
            }
            Tensor::InsideDescribe::Region reg;
            reg.origin = inputs[0];
            reg.size[0] = 1;
            reg.size[1] = outside;
            reg.size[2] = inside;
            reg.src.offset = indicePtr[i] * inside;
            reg.src.stride[0] = 0;
            reg.src.stride[1] = inside * axisLen;
            reg.src.stride[2] = 1;
            reg.dst.offset = i * inside;
            reg.dst.stride[1] = inside * N;
            reg.dst.stride[2] = 1;
            outDes->regions.emplace_back(std::move(reg));
        }
        return;
    }
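    // General path: build a Loop op with a single RegionCommand. Tensor indexes
    // are {output = 2, params = 0}; iterIndexes {-1, 1} means the destination
    // offset advances with the loop counter while the source offset is driven
    // by the value read from the indices tensor (tensor 1). Each iteration
    // copies an outside x inside slice of params into the i-th output slice.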
    flatbuffers::FlatBufferBuilder builder;
    OpBuilder unaryOp(builder);
    unaryOp.add_type(OpType_UnaryOp);
    auto unaryOpOffset = unaryOp.Finish();
    auto iterIndexesOffset = builder.CreateVector(std::vector<int>{-1, 1});
    auto stepOffset = builder.CreateVector(std::vector<int>{inside, inside});
    auto indexesOffset = builder.CreateVector(std::vector<int>{2, 0});
    auto sizeOffset = builder.CreateVector(std::vector<int>{outside, 1, inside});
    // View 0
    auto view0Stride = builder.CreateVector(std::vector<int>{inside * N, inside, 1});
    ViewBuilder view0Builder(builder);
    view0Builder.add_offset(0);
    view0Builder.add_stride(view0Stride);
    auto view0Offset = view0Builder.Finish();
    // View 1
    auto view1Stride = builder.CreateVector(std::vector<int>{inside * params->length(axis), inside, 1});
    ViewBuilder view1Builder(builder);
    view1Builder.add_offset(0);
    view1Builder.add_stride(view1Stride);
    auto view1Offset = view1Builder.Finish();
    auto viewAllOffset = builder.CreateVector<flatbuffers::Offset<View>>({view0Offset, view1Offset});

    RegionCommandBuilder rcmdBuild(builder);
    rcmdBuild.add_op(unaryOpOffset);
    rcmdBuild.add_view(viewAllOffset);
    rcmdBuild.add_indexes(indexesOffset);
    rcmdBuild.add_iterIndexes(iterIndexesOffset);
    rcmdBuild.add_steps(stepOffset);
    rcmdBuild.add_size(sizeOffset);
    auto rcmdOffset = rcmdBuild.Finish();
    auto rcmdAllOffset = builder.CreateVector<flatbuffers::Offset<RegionCommand>>({rcmdOffset});
    auto inputIndexesOffset = builder.CreateVector(std::vector<int>{0, 1});
    auto outputIndexesOffset = builder.CreateVector(std::vector<int>{2});
    LoopParamBuilder loopBuilder(builder);
    loopBuilder.add_commands(rcmdAllOffset);
    loopBuilder.add_loopNumber(indices->elementSize());
    loopBuilder.add_tensorNumber(3);
    loopBuilder.add_inputIndexes(inputIndexesOffset);
    loopBuilder.add_outputIndexes(outputIndexesOffset);
    auto loopOffset = loopBuilder.Finish();
    flatbuffers::Offset<flatbuffers::String> nameOffset;
    if (nullptr != op->name()) {
        nameOffset = builder.CreateString(op->name()->c_str());
    }
    OpBuilder finishBuilder(builder);
    finishBuilder.add_main(loopOffset.Union());
    finishBuilder.add_main_type(OpParameter_LoopParam);
    finishBuilder.add_type(OpType_While);
    if (nullptr != op->name()) {
        finishBuilder.add_name(nameOffset);
    }
    builder.Finish(finishBuilder.Finish());
    auto cmd = GeometryComputerUtils::makeCommand(builder, {params, indices}, outputs);
    TensorUtils::getDescribe(output)->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
    res.command.emplace_back(std::move(cmd));
}

class GeometryGather : public GeometryComputer {
public:
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        _computeGather(inputs, outputs, context, res, op);
        return true;
    }
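    // On recompute, patch the cached Loop command in place (loop number, steps,
    // sizes and view strides) instead of rebuilding the flatbuffer.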
    virtual bool onRecompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                             Context& context, CommandBuffer& cmd) const override {
        if (cmd.command.size() != 1) {
            return false;
        }
        int axis = 0;
        if (inputs.size() == 3) {
            const Tensor *axisTensor = inputs[2];
            axis = axisTensor->host<int32_t>()[0];
        }
        if (op->main_type() == OpParameter_Axis) {
            axis = op->main_as_Axis()->axis();
        }
        auto params = inputs[0];
        auto indices = inputs[1];
        auto output = outputs[0];
        MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions);
        if (axis < 0) {
            axis = params->buffer().dimensions + axis;
        }
        const int gatherDimSize = params->buffer().dim[axis].extent;
        const int N = indices->elementSize();
        MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max());

        int inside = 1;
        int outside = 1;
        for (int i = 0; i < axis; ++i) {
            outside *= params->length(i);
        }
        for (int i = axis + 1; i < params->dimensions(); ++i) {
            inside *= params->length(i);
        }
        auto loopCmd = cmd.command[0];
        auto param = loopCmd->op->main_as_LoopParam();
        // Reset parameters for last command
        ((flatbuffers::Table*)param)->SetField(LoopParam::VT_LOOPNUMBER, indices->elementSize(), 0);
        auto rgcmd = param->commands()->GetAs<RegionCommand>(0);
        auto step = (int*)rgcmd->steps()->data();
        step[0] = inside;
        step[1] = inside;
        auto size = (int*)rgcmd->size()->data();
        size[0] = outside;
        size[2] = inside;
        auto view0Stride = (int*)rgcmd->view()->GetAs<View>(0)->stride()->data();
        view0Stride[0] = inside * N;
        view0Stride[1] = inside;
        auto view1Stride = (int*)rgcmd->view()->GetAs<View>(1)->stride()->data();
        view1Stride[0] = inside * params->length(axis);
        view1Stride[1] = inside;
        return true;
    }
};
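
// GatherND is decomposed into elementwise geometry: the per-coordinate strides
// of params are written into a small constant tensor, broadcast to the [N, D]
// index tensor, multiplied with the indices and reduce-summed into one flat
// element offset per index row; a final Loop op then copies S contiguous
// elements for every row.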
class GeometryGatherND : public GeometryComputer {
public:
    enum MID_POSITION {
        P_constStride = 0,
        P_reshapeIndice = 1,
        P_broadcastStride = 2,
        P_mulIndice = 3,
        P_indiceOneLine = 4,
        P_MAX
    };
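    // N: number of index rows (product of all but the last indice dimension).
    // D: coordinates per row (last indice dimension).
    // S: contiguous elements copied per row.
    // B: number of leading batch dimensions, taken from the Axis parameter.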
    static bool buildGatherND(const Op* op, Tensor* params, Tensor* indice, Tensor* output,
                              int N, int D, int S, Context& context, CommandBuffer& res, int B) {
        int paramSize = 1;
        for (int i = B; i < params->dimensions(); ++i) {
            paramSize *= params->length(i);
        }
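        // constStride[d] holds the element stride of coordinate d inside one
        // batch of params, so an index row [i0, ..., i(D-1)] maps to the flat
        // element offset sum(i_d * constStride[d]).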
        std::array<std::shared_ptr<Tensor>, 5> midTensors;
        std::shared_ptr<Tensor> constStride(Tensor::createDevice<int>({D}));
        if (!context.allocTensor(constStride.get())) {
            return false;
        }
        midTensors[P_constStride] = constStride;
        for (int i = 0; i < D; ++i) {
            int dimCount = paramSize / params->length(i + B);
            constStride->host<int>()[i] = dimCount;
            paramSize = dimCount;
        }
        std::shared_ptr<Tensor> reshapeIndice(Tensor::createDevice<int>({N, D}));
        midTensors[P_reshapeIndice] = reshapeIndice;
        {
            auto des = TensorUtils::getDescribe(reshapeIndice.get());
            des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, N * D)};
        }
        std::shared_ptr<Tensor> broadcastStride(Tensor::createDevice<int>({N, D}));
        midTensors[P_broadcastStride] = broadcastStride;
        {
            // [D] => [N, D]
            auto des = TensorUtils::getDescribe(broadcastStride.get());
            des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->regions.resize(1);
            des->regions[0].origin = constStride.get();
            des->regions[0].size[0] = 1;
            des->regions[0].size[1] = N;
            des->regions[0].size[2] = D;
            des->regions[0].dst.stride[0] = D * N;
            des->regions[0].dst.stride[1] = D;
            des->regions[0].dst.stride[2] = 1;
            des->regions[0].src.stride[0] = 0;
            des->regions[0].src.stride[1] = 0;
            des->regions[0].src.stride[2] = 1;
        }
        std::shared_ptr<Tensor> mulIndice(Tensor::createDevice<int>({N, D}));
        midTensors[P_mulIndice] = mulIndice;
        {
            // [N, D] * [N, D] => [N, D]
            auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, reshapeIndice.get(), broadcastStride.get(), mulIndice.get());
            res.command.emplace_back(std::move(cmd));
        }
        std::shared_ptr<Tensor> indiceOneLine(Tensor::createDevice<int>({N, 1}));
        midTensors[P_indiceOneLine] = indiceOneLine;
        {
            // [N, D] => [N, 1]
            auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, mulIndice.get(), indiceOneLine.get());
            res.command.emplace_back(std::move(cmd));
        }
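        // Loop: tensor indexes are {output = 2, params = 0}; the source offset
        // of iteration i is indiceOneLine[i] (step 1) and the destination
        // offset is i * S, so each iteration copies S contiguous elements.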
        flatbuffers::FlatBufferBuilder builder;
        OpBuilder unaryOp(builder);
        unaryOp.add_type(OpType_UnaryOp);
        auto unaryOpOffset = unaryOp.Finish();
        auto iterIndexesOffset = builder.CreateVector(std::vector<int>{-1, 1});
        auto stepOffset = builder.CreateVector(std::vector<int>{S, 1});
        auto indexesOffset = builder.CreateVector(std::vector<int>{2, 0});
        auto sizeOffset = builder.CreateVector(std::vector<int>{1, 1, S});
        // View 0
        auto view0Stride = builder.CreateVector(std::vector<int>{S, S, 1});
        ViewBuilder view0Builder(builder);
        view0Builder.add_offset(0);
        view0Builder.add_stride(view0Stride);
        auto view0Offset = view0Builder.Finish();
        // view0 and view1 are the same
        auto viewAllOffset = builder.CreateVector<flatbuffers::Offset<View>>({view0Offset, view0Offset});

        RegionCommandBuilder rcmdBuild(builder);
        rcmdBuild.add_op(unaryOpOffset);
        rcmdBuild.add_view(viewAllOffset);
        rcmdBuild.add_indexes(indexesOffset);
        rcmdBuild.add_iterIndexes(iterIndexesOffset);
        rcmdBuild.add_steps(stepOffset);
        rcmdBuild.add_size(sizeOffset);
        auto rcmdOffset = rcmdBuild.Finish();
        auto rcmdAllOffset = builder.CreateVector<flatbuffers::Offset<RegionCommand>>({rcmdOffset});
        auto inputIndexesOffset = builder.CreateVector(std::vector<int>{0, 1});
        auto outputIndexesOffset = builder.CreateVector(std::vector<int>{2});
        LoopParamBuilder loopBuilder(builder);
        loopBuilder.add_commands(rcmdAllOffset);
        loopBuilder.add_loopNumber(N);
        loopBuilder.add_tensorNumber(3);
        loopBuilder.add_inputIndexes(inputIndexesOffset);
        loopBuilder.add_outputIndexes(outputIndexesOffset);
        auto loopOffset = loopBuilder.Finish();
        flatbuffers::Offset<flatbuffers::String> nameOffset;
        if (nullptr != op->name()) {
            nameOffset = builder.CreateString(op->name()->c_str());
        }
        OpBuilder finishBuilder(builder);
        finishBuilder.add_main(loopOffset.Union());
        finishBuilder.add_main_type(OpParameter_LoopParam);
        finishBuilder.add_type(OpType_While);
        if (nullptr != op->name()) {
            finishBuilder.add_name(nameOffset);
        }
        builder.Finish(finishBuilder.Finish());
        auto cmd = GeometryComputerUtils::makeCommand(builder, {params, indiceOneLine.get()}, {output});
        TensorUtils::getDescribe(output)->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
        res.command.emplace_back(std::move(cmd));
        res.extras.insert(res.extras.end(), midTensors.begin(), midTensors.end());
        return true;
    }
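    // On recompute, resize and refill the cached intermediate tensors and patch
    // the cached Loop command in place rather than rebuilding the whole
    // command sequence.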
    virtual bool onRecompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                             Context& context, CommandBuffer& cmd) const override {
        if (cmd.extras.size() != P_MAX) {
            return false;
        }
        MNN_ASSERT(2 == inputs.size());
        MNN_ASSERT(1 == outputs.size());
        auto params = inputs[0];
        auto indice = inputs[1];
        auto output = outputs[0];
        int mSliceN = 1;
        int mSliceSize = 1;
        int batchDim = 0;
        if (nullptr != op->main_as_Axis()) {
            batchDim = op->main_as_Axis()->axis();
        }

        for (int i = 0; i < indice->dimensions() - 1; ++i) {
            mSliceN *= indice->length(i);
        }
        auto indiceNd = indice->length(indice->dimensions() - 1);
        for (int i = indiceNd + batchDim; i < params->dimensions(); ++i) {
            mSliceSize *= params->length(i);
        }
        int paramSize = 1;
        for (int i = batchDim; i < params->dimensions(); ++i) {
            paramSize *= params->length(i);
        }
        auto constStride = cmd.extras[P_constStride];
        auto reshapeIndice = cmd.extras[P_reshapeIndice];
        auto broadcastStride = cmd.extras[P_broadcastStride];
        auto mulIndice = cmd.extras[P_mulIndice];
        auto indiceOneLine = cmd.extras[P_indiceOneLine];
        // Set length
        bool needAlloc = constStride->length(0) < indiceNd;
        constStride->setLength(0, indiceNd);
        reshapeIndice->setLength(0, mSliceN);
        reshapeIndice->setLength(1, indiceNd);
        broadcastStride->setLength(0, mSliceN);
        broadcastStride->setLength(1, indiceNd);
        mulIndice->setLength(0, mSliceN);
        mulIndice->setLength(1, indiceNd);
        indiceOneLine->setLength(0, mSliceN);
        indiceOneLine->setLength(1, 1);

        if (needAlloc) {
            if (!context.allocTensor(constStride.get())) {
                return false;
            }
        }
        for (int i = 0; i < indiceNd; ++i) {
            int dimCount = paramSize / params->length(i + batchDim);
            constStride->host<int>()[i] = dimCount;
            paramSize = dimCount;
        }
        // recompute reshape
        auto des = TensorUtils::getDescribe(reshapeIndice.get());
        des->extra.offset = 0;
        des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        des->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, mSliceN * indiceNd)};
        // recompute broadcast
        des = TensorUtils::getDescribe(broadcastStride.get());
        des->extra.offset = 0;
        des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        des->regions[0].origin = constStride.get();
        des->regions[0].size[0] = 1;
        des->regions[0].size[1] = mSliceN;
        des->regions[0].size[2] = indiceNd;
        des->regions[0].dst.stride[0] = indiceNd * mSliceN;
        des->regions[0].dst.stride[1] = indiceNd;
        des->regions[0].dst.stride[2] = 1;
        // recompute loop
        auto loopCmd = cmd.command[cmd.command.size() - 1];
        auto param = loopCmd->op->main_as_LoopParam();
        // Reset parameters for last command
        ((flatbuffers::Table*)param)->SetField(LoopParam::VT_LOOPNUMBER, mSliceN, 0);
        auto rgCmd = param->commands()->GetAs<RegionCommand>(0);
        auto stepData = (int*)rgCmd->steps()->data();
        stepData[0] = mSliceSize;
        auto sizeData = (int*)rgCmd->size()->data();
        sizeData[2] = mSliceSize;
        auto strideData = (int*)rgCmd->view()->GetAs<View>(0)->stride()->data();
        strideData[0] = mSliceSize;
        strideData[1] = mSliceSize;
        strideData = (int*)rgCmd->view()->GetAs<View>(1)->stride()->data();
        strideData[0] = mSliceSize;
        strideData[1] = mSliceSize;
        return true;
    }
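
    // First-time lowering: derive N, D and S from the current shapes, then
    // build the full command sequence via buildGatherND.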
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        MNN_ASSERT(2 == inputs.size());
        MNN_ASSERT(1 == outputs.size());
        auto params = inputs[0];
        auto indice = inputs[1];
        auto output = outputs[0];
        int batchDim = 0;
        if (nullptr != op->main_as_Axis()) {
            batchDim = op->main_as_Axis()->axis();
        }

        int N = 1;
        int S = 1;
        for (int i = 0; i < indice->dimensions() - 1; ++i) {
            N *= indice->length(i);
        }
        auto D = indice->length(indice->dimensions() - 1);
        for (int i = D + batchDim; i < params->dimensions(); ++i) {
            S *= params->length(i);
        }
        return buildGatherND(op, params, indice, output, N, D, S, context, res, batchDim);
    }
};
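
// GatherElements is reduced to GatherND: for every output element a full
// D-dimensional coordinate is materialized, taking the coordinate along `axis`
// from the given indices and every other coordinate from the element's own
// position, then buildGatherND copies one element per row (S = 1).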
class GeometryGatherElements : public GeometryComputer {
public:
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        auto data = inputs[0];
        auto indices = inputs[1];
        auto output = outputs[0];
        int axis = 0;
        if (inputs.size() >= 3) {
            auto axisTensor = inputs[2];
            axis = axisTensor->host<int>()[0];
        }
        auto D = data->buffer().dimensions;
        auto N = indices->elementSize();
        if (axis < 0) {
            axis = D + axis;
        }
        // flatten indices/update
        std::shared_ptr<Tensor> flattenIndice(Tensor::createDevice<int>({N}));
        {
            auto ides = TensorUtils::getDescribe(flattenIndice.get());
            ides->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            ides->regions = {GeometryComputerUtils::makeRawAddressRef(indices, 0, N)};
            res.extras.emplace_back(flattenIndice);
        }
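        // Build the [N, D] coordinate tensor: column `axis` comes from the
        // flattened indices, every other column d holds each element's own
        // position along dimension d; dst.offset = d with dst.stride[2] = D
        // interleaves the columns.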
        // reindex
        std::shared_ptr<Tensor> newIndice(Tensor::createDevice<int>({N, D}));
        {
            auto des = TensorUtils::getDescribe(newIndice.get());
            des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->regions.resize(D);
            for (int i = 0; i < D; i++) {
                if (i == axis) {
                    des->regions[i].origin = flattenIndice.get();
                } else {
                    int inner = 1, outer = 1, middle = indices->shape()[i];
                    for (int j = 0; j < i; j++) outer *= indices->shape()[j];
                    for (int j = i + 1; j < D; j++) inner *= indices->shape()[j];
                    MNN_ASSERT(N == inner * middle * outer);
                    auto subIndice = context.allocConst(op, {N}, halide_type_of<int>());
                    auto ptr = subIndice->host<int>();
                    int idx = 0;
                    for (int out = 0; out < outer; out++) {
                        for (int mid = 0; mid < middle; mid++) {
                            for (int in = 0; in < inner; in++) {
                                ptr[idx++] = mid;
                            }
                        }
                    }
                    des->regions[i].origin = subIndice.get();
                }
                des->regions[i].size[2] = N;
                des->regions[i].dst.stride[2] = D;
                des->regions[i].dst.offset = i;
            }
            res.extras.emplace_back(newIndice);
        }
        return GeometryGatherND::buildGatherND(op, data, newIndice.get(), output, N, D, 1, context, res, 0);
    }
};
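
// These gather variants are registered only for the Loop compiler, since their
// lowering relies on the Loop (While) op.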
static void _create() {
    std::shared_ptr<GeometryComputer> comp(new GeometryGather);
    GeometryComputer::registerGeometryComputer(comp, {OpType_Gather, OpType_GatherV2}, Runtime::Compiler_Loop);
    std::shared_ptr<GeometryComputer> comp2(new GeometryGatherND);
    GeometryComputer::registerGeometryComputer(comp2, {OpType_GatherND}, Runtime::Compiler_Loop);
    std::shared_ptr<GeometryComputer> comp3(new GeometryGatherElements);
    GeometryComputer::registerGeometryComputer(comp3, {OpType_GatherElements}, Runtime::Compiler_Loop);
}

REGISTER_GEOMETRY(GeometryGather, _create);

} // namespace MNN