| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  GeometryGather.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2020/06/09.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "geometry/GeometryComputer.hpp"
 | 
					
						
							|  |  |  | #include "geometry/GeometryComputerUtils.hpp"
 | 
					
						
							|  |  |  | #include "core/OpCommonUtils.hpp"
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | static void _computeGather(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                            GeometryComputer::Context& context, CommandBuffer& res, const Op* op) { | 
					
						
							|  |  |  |     int axis = 0; | 
					
						
							|  |  |  |     if (inputs.size() == 3) { | 
					
						
							|  |  |  |         const Tensor *axisTensor = inputs[2]; | 
					
						
							|  |  |  |         axis                     = axisTensor->host<int32_t>()[0]; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (op->main_type() == OpParameter_Axis) { | 
					
						
							|  |  |  |         axis = op->main_as_Axis()->axis(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto params  = inputs[0]; | 
					
						
							|  |  |  |     auto indices = inputs[1]; | 
					
						
							|  |  |  |     auto output  = outputs[0]; | 
					
						
							|  |  |  |     MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions); | 
					
						
							|  |  |  |     if (axis < 0) { | 
					
						
							|  |  |  |         axis = params->buffer().dimensions + axis; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     const int gatherDimSize = params->buffer().dim[axis].extent; | 
					
						
							|  |  |  |     const int N             = indices->elementSize(); | 
					
						
							|  |  |  |     MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int inside  = 1; | 
					
						
							|  |  |  |     int outside = 1; | 
					
						
							|  |  |  |     for (int i = 0; i < axis; ++i) { | 
					
						
							|  |  |  |         outside *= params->length(i); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (int i = axis + 1; i < params->dimensions(); ++i) { | 
					
						
							|  |  |  |         inside *= params->length(i); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2024-04-19 11:58:21 +08:00
										 |  |  |     const int limit = 3; | 
					
						
							|  |  |  |     if (TensorUtils::getDescribe(indices)->usage == Tensor::InsideDescribe::CONSTANT && N < limit) { | 
					
						
							|  |  |  |         // Use Raster instead of loop
 | 
					
						
							|  |  |  |         auto outDes = TensorUtils::getDescribe(output); | 
					
						
							|  |  |  |         outDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |         outDes->regions.clear(); | 
					
						
							|  |  |  |         outDes->regions.reserve(N); | 
					
						
							|  |  |  |         auto indicePtr = indices->host<int>(); | 
					
						
							|  |  |  |         auto axisLen = params->length(axis); | 
					
						
							|  |  |  |         for (int i=0; i<N; ++i) { | 
					
						
							|  |  |  |             if (indicePtr[i] < 0 || indicePtr[i] >= axisLen) { | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             Tensor::InsideDescribe::Region reg; | 
					
						
							|  |  |  |             reg.origin = inputs[0]; | 
					
						
							|  |  |  |             reg.size[0] = 1; | 
					
						
							|  |  |  |             reg.size[1] = outside; | 
					
						
							|  |  |  |             reg.size[2] = inside; | 
					
						
							|  |  |  |             reg.src.offset = indicePtr[i] * inside; | 
					
						
							|  |  |  |             reg.src.stride[0] = 0; | 
					
						
							|  |  |  |             reg.src.stride[1] = inside * axisLen; | 
					
						
							|  |  |  |             reg.src.stride[2] = 1; | 
					
						
							|  |  |  |             reg.dst.offset = i * inside; | 
					
						
							|  |  |  |             reg.dst.stride[1] = inside * N; | 
					
						
							|  |  |  |             reg.dst.stride[2] = 1; | 
					
						
							|  |  |  |             outDes->regions.emplace_back(std::move(reg)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     flatbuffers::FlatBufferBuilder builder; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     OpBuilder unaryOp(builder); | 
					
						
							|  |  |  |     unaryOp.add_type(OpType_UnaryOp); | 
					
						
							|  |  |  |     auto unaryOpPffset = unaryOp.Finish(); | 
					
						
							|  |  |  |     auto iterIndexesOffset = builder.CreateVector(std::vector<int>{-1, 1}); | 
					
						
							|  |  |  |     auto stepOffset = builder.CreateVector(std::vector<int>{inside, inside}); | 
					
						
							|  |  |  |     auto indexesOffset = builder.CreateVector(std::vector<int>{2, 0}); | 
					
						
							|  |  |  |     auto sizeOffset = builder.CreateVector(std::vector<int>{outside, 1, inside}); | 
					
						
							|  |  |  |     // View 0
 | 
					
						
							|  |  |  |     auto view0Stride = builder.CreateVector(std::vector<int>{inside * N, inside, 1}); | 
					
						
							|  |  |  |     ViewBuilder view0Builder(builder); | 
					
						
							|  |  |  |     view0Builder.add_offset(0); | 
					
						
							|  |  |  |     view0Builder.add_stride(view0Stride); | 
					
						
							|  |  |  |     auto view0Offset = view0Builder.Finish(); | 
					
						
							|  |  |  |     // View1
 | 
					
						
							|  |  |  |     auto view1Stride = builder.CreateVector(std::vector<int>{inside * params->length(axis), inside, 1}); | 
					
						
							|  |  |  |     ViewBuilder view1Builder(builder); | 
					
						
							|  |  |  |     view1Builder.add_offset(0); | 
					
						
							|  |  |  |     view1Builder.add_stride(view1Stride); | 
					
						
							|  |  |  |     auto view1Offset = view1Builder.Finish(); | 
					
						
							|  |  |  |     auto viewAllOffset = builder.CreateVector<flatbuffers::Offset<View>>({view0Offset, view1Offset}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     RegionCommandBuilder rcmdBuild(builder); | 
					
						
							|  |  |  |     rcmdBuild.add_op(unaryOpPffset); | 
					
						
							|  |  |  |     rcmdBuild.add_view(viewAllOffset); | 
					
						
							|  |  |  |     rcmdBuild.add_indexes(indexesOffset); | 
					
						
							|  |  |  |     rcmdBuild.add_iterIndexes(iterIndexesOffset); | 
					
						
							|  |  |  |     rcmdBuild.add_steps(stepOffset); | 
					
						
							|  |  |  |     rcmdBuild.add_size(sizeOffset); | 
					
						
							|  |  |  |     auto rcmdOffset = rcmdBuild.Finish(); | 
					
						
							|  |  |  |     auto rcmdAllOffset = builder.CreateVector<flatbuffers::Offset<RegionCommand>>({rcmdOffset}); | 
					
						
							|  |  |  |     auto inputIndexesOffset = builder.CreateVector(std::vector<int>{0, 1}); | 
					
						
							|  |  |  |     auto outputIndexesOffset = builder.CreateVector(std::vector<int>{2}); | 
					
						
							|  |  |  |     LoopParamBuilder loopBuilder(builder); | 
					
						
							|  |  |  |     loopBuilder.add_commands(rcmdAllOffset); | 
					
						
							|  |  |  |     loopBuilder.add_loopNumber(indices->elementSize()); | 
					
						
							|  |  |  |     loopBuilder.add_tensorNumber(3); | 
					
						
							|  |  |  |     loopBuilder.add_inputIndexes(inputIndexesOffset); | 
					
						
							|  |  |  |     loopBuilder.add_outputIndexes(outputIndexesOffset); | 
					
						
							|  |  |  |     auto loopOffset = loopBuilder.Finish(); | 
					
						
							|  |  |  |     flatbuffers::Offset<flatbuffers::String> nameOffset; | 
					
						
							|  |  |  |     if (nullptr != op->name()) { | 
					
						
							|  |  |  |         nameOffset = builder.CreateString(op->name()->c_str()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     OpBuilder finishBuilder(builder); | 
					
						
							|  |  |  |     finishBuilder.add_main(loopOffset.Union()); | 
					
						
							|  |  |  |     finishBuilder.add_main_type(OpParameter_LoopParam); | 
					
						
							|  |  |  |     finishBuilder.add_type(OpType_While); | 
					
						
							|  |  |  |     if (nullptr != op->name()) { | 
					
						
							|  |  |  |         finishBuilder.add_name(nameOffset); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     builder.Finish(finishBuilder.Finish()); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     auto cmd = GeometryComputerUtils::makeCommand(builder, {params, indices}, outputs); | 
					
						
							|  |  |  |     TensorUtils::getDescribe(output)->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND; | 
					
						
							|  |  |  |     res.command.emplace_back(std::move(cmd)); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | class GeometryGather : public GeometryComputer { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                            Context& context, CommandBuffer& res) const override { | 
					
						
							|  |  |  |         _computeGather(inputs, outputs, context, res, op); | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     virtual bool onRecompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                              Context& context, CommandBuffer& cmd) const override { | 
					
						
							|  |  |  |         if (cmd.command.size() != 1) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         int axis = 0; | 
					
						
							|  |  |  |         if (inputs.size() == 3) { | 
					
						
							|  |  |  |             const Tensor *axisTensor = inputs[2]; | 
					
						
							|  |  |  |             axis                     = axisTensor->host<int32_t>()[0]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->main_type() == OpParameter_Axis) { | 
					
						
							|  |  |  |             axis = op->main_as_Axis()->axis(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto params  = inputs[0]; | 
					
						
							|  |  |  |         auto indices = inputs[1]; | 
					
						
							|  |  |  |         auto output  = outputs[0]; | 
					
						
							|  |  |  |         MNN_ASSERT(axis > -params->buffer().dimensions && axis < params->buffer().dimensions); | 
					
						
							|  |  |  |         if (axis < 0) { | 
					
						
							|  |  |  |             axis = params->buffer().dimensions + axis; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         const int gatherDimSize = params->buffer().dim[axis].extent; | 
					
						
							|  |  |  |         const int N             = indices->elementSize(); | 
					
						
							|  |  |  |         MNN_ASSERT(gatherDimSize <= std::numeric_limits<int32_t>::max()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int inside  = 1; | 
					
						
							|  |  |  |         int outside = 1; | 
					
						
							|  |  |  |         for (int i = 0; i < axis; ++i) { | 
					
						
							|  |  |  |             outside *= params->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = axis + 1; i < params->dimensions(); ++i) { | 
					
						
							|  |  |  |             inside *= params->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto loopCmd = cmd.command[0]; | 
					
						
							|  |  |  |         auto param = loopCmd->op->main_as_LoopParam(); | 
					
						
							|  |  |  |         // Reset parameters for last command
 | 
					
						
							|  |  |  |         ((flatbuffers::Table*)param)->SetField(LoopParam::VT_LOOPNUMBER, indices->elementSize(), 0); | 
					
						
							|  |  |  |         auto rgcmd = param->commands()->GetAs<RegionCommand>(0); | 
					
						
							|  |  |  |         auto step = (int*)rgcmd->steps()->data(); | 
					
						
							|  |  |  |         step[0] = inside; | 
					
						
							|  |  |  |         step[1] = inside; | 
					
						
							|  |  |  |         auto size = (int*)rgcmd->size()->data(); | 
					
						
							|  |  |  |         size[0] = outside; | 
					
						
							|  |  |  |         size[2] = inside; | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  |         auto view0Stride = (int*)rgcmd->view()->GetAs<View>(0)->stride()->data(); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         view0Stride[0] = inside * N; | 
					
						
							|  |  |  |         view0Stride[1] = inside; | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  |         auto view1Stride = (int*)rgcmd->view()->GetAs<View>(1)->stride()->data(); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         view1Stride[0] = inside * params->length(axis); | 
					
						
							|  |  |  |         view1Stride[1] = inside; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | class GeometryGatherND : public GeometryComputer { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     enum MID_POSITION { | 
					
						
							|  |  |  |         P_constStride = 0, | 
					
						
							|  |  |  |         P_reshapeIndice = 1, | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         P_broadcastStride = 2, | 
					
						
							|  |  |  |         P_mulIndice = 3, | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         P_indiceOneLine = 4, | 
					
						
							|  |  |  |         P_MAX | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |     static bool buildGatherND(const Op* op, Tensor* params, Tensor* indice, Tensor* output, | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                               int N, int D, int S, Context& context, CommandBuffer& res, int B) { | 
					
						
							|  |  |  |         int paramSize = 1; | 
					
						
							|  |  |  |         for (int i=B; i<params->dimensions(); ++i) { | 
					
						
							|  |  |  |             paramSize *= params->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         std::array<std::shared_ptr<Tensor>, 5> midTensors; | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> constStride(Tensor::createDevice<int>({D})); | 
					
						
							|  |  |  |         if (!context.allocTensor(constStride.get())) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         midTensors[P_constStride] = constStride; | 
					
						
							|  |  |  |         for (int i=0; i<D; ++i) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |             int dimCount = paramSize / params->length(i + B); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             constStride->host<int>()[i] = dimCount; | 
					
						
							|  |  |  |             paramSize = dimCount; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> reshapeIndice(Tensor::createDevice<int>({N, D})); | 
					
						
							|  |  |  |         midTensors[P_reshapeIndice] = reshapeIndice; | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             auto des = TensorUtils::getDescribe(reshapeIndice.get()); | 
					
						
							|  |  |  |             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             des->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, N * D)}; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> broadcastStride(Tensor::createDevice<int>({N, D})); | 
					
						
							|  |  |  |         midTensors[P_broadcastStride] = broadcastStride; | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             // [D] => [N, D]
 | 
					
						
							|  |  |  |             auto des = TensorUtils::getDescribe(broadcastStride.get()); | 
					
						
							|  |  |  |             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             des->regions.resize(1); | 
					
						
							|  |  |  |             des->regions[0].origin = constStride.get(); | 
					
						
							|  |  |  |             des->regions[0].size[0] = 1; | 
					
						
							|  |  |  |             des->regions[0].size[1] = N; | 
					
						
							|  |  |  |             des->regions[0].size[2] = D; | 
					
						
							|  |  |  |             des->regions[0].dst.stride[0] = D*N; | 
					
						
							|  |  |  |             des->regions[0].dst.stride[1] = D; | 
					
						
							|  |  |  |             des->regions[0].dst.stride[2] = 1; | 
					
						
							|  |  |  |             des->regions[0].src.stride[0] = 0; | 
					
						
							|  |  |  |             des->regions[0].src.stride[1] = 0; | 
					
						
							|  |  |  |             des->regions[0].src.stride[2] = 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> mulIndice(Tensor::createDevice<int>({N, D})); | 
					
						
							|  |  |  |         midTensors[P_mulIndice] = mulIndice; | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             // [N, D] * [N, D] => [N, D]
 | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, reshapeIndice.get(), broadcastStride.get(), mulIndice.get()); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> indiceOneLine(Tensor::createDevice<int>({N, 1})); | 
					
						
							|  |  |  |         midTensors[P_indiceOneLine] = indiceOneLine; | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             // [N, D] => [N, 1]
 | 
					
						
							|  |  |  |             auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, mulIndice.get(), indiceOneLine.get()); | 
					
						
							|  |  |  |             res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         flatbuffers::FlatBufferBuilder builder; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         OpBuilder unaryOp(builder); | 
					
						
							|  |  |  |         unaryOp.add_type(OpType_UnaryOp); | 
					
						
							|  |  |  |         auto unaryOpPffset = unaryOp.Finish(); | 
					
						
							|  |  |  |         auto iterIndexesOffset = builder.CreateVector(std::vector<int>{-1, 1}); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto stepOffset = builder.CreateVector(std::vector<int>{S, 1}); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         auto indexesOffset = builder.CreateVector(std::vector<int>{2, 0}); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto sizeOffset = builder.CreateVector(std::vector<int>{1, 1, S}); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         // View 0
 | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto view0Stride = builder.CreateVector(std::vector<int>{S, S, 1}); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         ViewBuilder view0Builder(builder); | 
					
						
							|  |  |  |         view0Builder.add_offset(0); | 
					
						
							|  |  |  |         view0Builder.add_stride(view0Stride); | 
					
						
							|  |  |  |         auto view0Offset = view0Builder.Finish(); | 
					
						
							|  |  |  |         // view0 and view1 is the same
 | 
					
						
							|  |  |  |         auto viewAllOffset = builder.CreateVector<flatbuffers::Offset<View>>({view0Offset, view0Offset}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         RegionCommandBuilder rcmdBuild(builder); | 
					
						
							|  |  |  |         rcmdBuild.add_op(unaryOpPffset); | 
					
						
							|  |  |  |         rcmdBuild.add_view(viewAllOffset); | 
					
						
							|  |  |  |         rcmdBuild.add_indexes(indexesOffset); | 
					
						
							|  |  |  |         rcmdBuild.add_iterIndexes(iterIndexesOffset); | 
					
						
							|  |  |  |         rcmdBuild.add_steps(stepOffset); | 
					
						
							|  |  |  |         rcmdBuild.add_size(sizeOffset); | 
					
						
							|  |  |  |         auto rcmdOffset = rcmdBuild.Finish(); | 
					
						
							|  |  |  |         auto rcmdAllOffset = builder.CreateVector<flatbuffers::Offset<RegionCommand>>({rcmdOffset}); | 
					
						
							|  |  |  |         auto inputIndexesOffset = builder.CreateVector(std::vector<int>{0, 1}); | 
					
						
							|  |  |  |         auto outputIndexesOffset = builder.CreateVector(std::vector<int>{2}); | 
					
						
							|  |  |  |         LoopParamBuilder loopBuilder(builder); | 
					
						
							|  |  |  |         loopBuilder.add_commands(rcmdAllOffset); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         loopBuilder.add_loopNumber(N); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         loopBuilder.add_tensorNumber(3); | 
					
						
							|  |  |  |         loopBuilder.add_inputIndexes(inputIndexesOffset); | 
					
						
							|  |  |  |         loopBuilder.add_outputIndexes(outputIndexesOffset); | 
					
						
							|  |  |  |         auto loopOffset = loopBuilder.Finish(); | 
					
						
							|  |  |  |         flatbuffers::Offset<flatbuffers::String> nameOffset; | 
					
						
							|  |  |  |         if (nullptr != op->name()) { | 
					
						
							|  |  |  |             nameOffset = builder.CreateString(op->name()->c_str()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         OpBuilder finishBuilder(builder); | 
					
						
							|  |  |  |         finishBuilder.add_main(loopOffset.Union()); | 
					
						
							|  |  |  |         finishBuilder.add_main_type(OpParameter_LoopParam); | 
					
						
							|  |  |  |         finishBuilder.add_type(OpType_While); | 
					
						
							|  |  |  |         if (nullptr != op->name()) { | 
					
						
							|  |  |  |             finishBuilder.add_name(nameOffset); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         builder.Finish(finishBuilder.Finish()); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto cmd = GeometryComputerUtils::makeCommand(builder, {params, indiceOneLine.get()}, {output}); | 
					
						
							|  |  |  |         TensorUtils::getDescribe(output)->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND; | 
					
						
							|  |  |  |         res.command.emplace_back(std::move(cmd)); | 
					
						
							|  |  |  |         res.extras.insert(res.extras.end(), midTensors.begin(), midTensors.end()); | 
					
						
							|  |  |  |         return true; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     virtual bool onRecompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                              Context& context, CommandBuffer& cmd) const override { | 
					
						
							|  |  |  |         if (cmd.extras.size() != P_MAX) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         MNN_ASSERT(2 == inputs.size()); | 
					
						
							|  |  |  |         MNN_ASSERT(1 == outputs.size()); | 
					
						
							|  |  |  |         auto params = inputs[0]; | 
					
						
							|  |  |  |         auto indice = inputs[1]; | 
					
						
							|  |  |  |         auto output = outputs[0]; | 
					
						
							|  |  |  |         int mSliceN    = 1; | 
					
						
							|  |  |  |         int mSliceSize = 1; | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         int batchDim = 0; | 
					
						
							|  |  |  |         if (nullptr != op->main_as_Axis()) { | 
					
						
							|  |  |  |             batchDim = op->main_as_Axis()->axis(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         for (int i = 0; i < indice->dimensions() - 1; ++i) { | 
					
						
							|  |  |  |             mSliceN *= indice->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto indiceNd = indice->length(indice->dimensions() - 1); | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         for (int i = indiceNd + batchDim; i < params->dimensions(); ++i) { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |             mSliceSize *= params->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         int paramSize = 1; | 
					
						
							|  |  |  |         for (int i=batchDim; i<params->dimensions(); ++i) { | 
					
						
							|  |  |  |             paramSize *= params->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         auto constStride = cmd.extras[P_constStride]; | 
					
						
							|  |  |  |         auto reshapeIndice = cmd.extras[P_reshapeIndice]; | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         auto broadcastStride = cmd.extras[P_broadcastStride]; | 
					
						
							|  |  |  |         auto mulIndice = cmd.extras[P_mulIndice]; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         auto indiceOneLine = cmd.extras[P_indiceOneLine]; | 
					
						
							|  |  |  |         // Set length
 | 
					
						
							|  |  |  |         bool needAlloc = constStride->length(0) < indiceNd; | 
					
						
							|  |  |  |         constStride->setLength(0, indiceNd); | 
					
						
							|  |  |  |         reshapeIndice->setLength(0, mSliceN); | 
					
						
							|  |  |  |         reshapeIndice->setLength(1, indiceNd); | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         broadcastStride->setLength(0, mSliceN); | 
					
						
							|  |  |  |         broadcastStride->setLength(1, indiceNd); | 
					
						
							|  |  |  |         mulIndice->setLength(0, mSliceN); | 
					
						
							|  |  |  |         mulIndice->setLength(1, indiceNd); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         indiceOneLine->setLength(0, mSliceN); | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         indiceOneLine->setLength(1, 1); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if (needAlloc) { | 
					
						
							|  |  |  |             if (!context.allocTensor(constStride.get())) { | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i=0; i<indiceNd; ++i) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |             int dimCount = paramSize / params->length(i + batchDim); | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |             constStride->host<int>()[i] = dimCount; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |             paramSize = dimCount; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         // recompute reshape
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         auto des = TensorUtils::getDescribe(reshapeIndice.get()); | 
					
						
							|  |  |  |         des->extra.offset = 0; | 
					
						
							|  |  |  |         des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |         des->regions = {GeometryComputerUtils::makeRawAddressRef(indice, 0, mSliceN * indiceNd)}; | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |         // recompute broadcast
 | 
					
						
							|  |  |  |         des = TensorUtils::getDescribe(broadcastStride.get()); | 
					
						
							|  |  |  |         des->extra.offset = 0; | 
					
						
							|  |  |  |         des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |         des->regions[0].origin = constStride.get(); | 
					
						
							|  |  |  |         des->regions[0].size[0] = 1; | 
					
						
							|  |  |  |         des->regions[0].size[1] = mSliceN; | 
					
						
							|  |  |  |         des->regions[0].size[2] = indiceNd; | 
					
						
							|  |  |  |         des->regions[0].dst.stride[0] = indiceNd*mSliceN; | 
					
						
							|  |  |  |         des->regions[0].dst.stride[1] = indiceNd; | 
					
						
							|  |  |  |         des->regions[0].dst.stride[2] = 1; | 
					
						
							|  |  |  |         // recompute loop
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         auto loopCmd = cmd.command[cmd.command.size() - 1]; | 
					
						
							|  |  |  |         auto param = loopCmd->op->main_as_LoopParam(); | 
					
						
							|  |  |  |         // Reset parameters for last command
 | 
					
						
							|  |  |  |         ((flatbuffers::Table*)param)->SetField(LoopParam::VT_LOOPNUMBER, mSliceN, 0); | 
					
						
							|  |  |  |         auto rgCmd = param->commands()->GetAs<RegionCommand>(0); | 
					
						
							|  |  |  |         auto stepData = (int*)rgCmd->steps()->data(); | 
					
						
							|  |  |  |         stepData[0] = mSliceSize; | 
					
						
							|  |  |  |         auto sizeData = (int*)rgCmd->size()->data(); | 
					
						
							|  |  |  |         sizeData[2] = mSliceSize; | 
					
						
							|  |  |  |         auto strideData = (int*)rgCmd->view()->GetAs<View>(0)->stride()->data(); | 
					
						
							|  |  |  |         strideData[0] = mSliceSize; | 
					
						
							|  |  |  |         strideData[1] = mSliceSize; | 
					
						
							|  |  |  |         strideData = (int*)rgCmd->view()->GetAs<View>(1)->stride()->data(); | 
					
						
							|  |  |  |         strideData[0] = mSliceSize; | 
					
						
							|  |  |  |         strideData[1] = mSliceSize; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                            Context& context, CommandBuffer& res) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(2 == inputs.size()); | 
					
						
							|  |  |  |         MNN_ASSERT(1 == outputs.size()); | 
					
						
							|  |  |  |         auto params = inputs[0]; | 
					
						
							|  |  |  |         auto indice = inputs[1]; | 
					
						
							|  |  |  |         auto output = outputs[0]; | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         int batchDim = 0; | 
					
						
							|  |  |  |         if (nullptr != op->main_as_Axis()) { | 
					
						
							|  |  |  |             batchDim = op->main_as_Axis()->axis(); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         int N = 1; | 
					
						
							|  |  |  |         int S = 1; | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         for (int i = 0; i < indice->dimensions() - 1; ++i) { | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             N *= indice->length(i); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto D = indice->length(indice->dimensions() - 1); | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         for (int i = D + batchDim; i < params->dimensions(); ++i) { | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             S *= params->length(i); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         return buildGatherND(op, params, indice, output, N, D, S, context, res, batchDim); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GeometryGatherElements : public GeometryComputer { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                            Context& context, CommandBuffer& res) const override { | 
					
						
							|  |  |  |         auto data     = inputs[0]; | 
					
						
							|  |  |  |         auto indices  = inputs[1]; | 
					
						
							|  |  |  |         auto output   = outputs[0]; | 
					
						
							|  |  |  |         int axis      = 0; | 
					
						
							|  |  |  |         if (inputs.size() >= 3) { | 
					
						
							|  |  |  |             auto axisTensor = inputs[2]; | 
					
						
							|  |  |  |             axis = axisTensor->host<int>()[0]; | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         auto D  = data->buffer().dimensions; | 
					
						
							|  |  |  |         auto N  = indices->elementSize(); | 
					
						
							|  |  |  |         if (axis < 0) { | 
					
						
							|  |  |  |             axis = D + axis; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         // flatten indices/update
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> flattenIndice(Tensor::createDevice<int>({N})); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         { | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             auto ides = TensorUtils::getDescribe(flattenIndice.get()); | 
					
						
							|  |  |  |             ides->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             ides->regions = {GeometryComputerUtils::makeRawAddressRef(indices, 0, N)}; | 
					
						
							|  |  |  |             res.extras.emplace_back(flattenIndice); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |         // reindex
 | 
					
						
							|  |  |  |         std::shared_ptr<Tensor> newIndice(Tensor::createDevice<int>({N, D})); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         { | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             auto des = TensorUtils::getDescribe(newIndice.get()); | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |             des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |             des->regions.resize(D); | 
					
						
							|  |  |  |             for (int i = 0; i < D; i++) { | 
					
						
							|  |  |  |                 if (i == axis) { | 
					
						
							|  |  |  |                     des->regions[i].origin = flattenIndice.get(); | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                     int inner = 1, outter = 1, middle = indices->shape()[i]; | 
					
						
							|  |  |  |                     for (int j = 0; j < i; j++) outter *= indices->shape()[j]; | 
					
						
							|  |  |  |                     for (int j = i + 1; j < D; j++) inner *= indices->shape()[j]; | 
					
						
							|  |  |  |                     MNN_ASSERT(N == inner * middle * outter); | 
					
						
							|  |  |  |                     auto subIndice = context.allocConst(op, {N}, halide_type_of<int>()); | 
					
						
							|  |  |  |                     auto ptr = subIndice->host<int>(); | 
					
						
							|  |  |  |                     int idx = 0; | 
					
						
							|  |  |  |                     for (int out = 0; out < outter; out++) { | 
					
						
							|  |  |  |                         for (int mid = 0; mid < middle; mid++) { | 
					
						
							|  |  |  |                             for (int in = 0; in < inner; in++) { | 
					
						
							|  |  |  |                                 ptr[idx++] = mid; | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     des->regions[i].origin = subIndice.get(); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 des->regions[i].size[2] = N; | 
					
						
							|  |  |  |                 des->regions[i].dst.stride[2] = D; | 
					
						
							|  |  |  |                 des->regions[i].dst.offset = i; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             res.extras.emplace_back(newIndice); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         return GeometryGatherND::buildGatherND(op, data, newIndice.get(), output, N, D, 1, context, res, 0); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | static void _create() { | 
					
						
							|  |  |  |     std::shared_ptr<GeometryComputer> comp(new GeometryGather); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     GeometryComputer::registerGeometryComputer(comp, {OpType_Gather, OpType_GatherV2}, Runtime::Compiler_Loop); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     std::shared_ptr<GeometryComputer> comp2(new GeometryGatherND); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     GeometryComputer::registerGeometryComputer(comp2, {OpType_GatherND}, Runtime::Compiler_Loop); | 
					
						
							| 
									
										
										
										
											2022-07-11 10:56:37 +08:00
										 |  |  |     std::shared_ptr<GeometryComputer> comp3(new GeometryGatherElements); | 
					
						
							|  |  |  |     GeometryComputer::registerGeometryComputer(comp3, {OpType_GatherElements}, Runtime::Compiler_Loop); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_GEOMETRY(GeometryGather, _create); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |