| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ShapeTensorArray.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2020/12/21.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "shape/SizeComputer.hpp"
 | 
					
						
							|  |  |  | #include "core/Macro.h"
 | 
					
						
							|  |  |  | #include "math.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | static void copyTensorArrayAttribute(const Tensor* src, Tensor* dst) { | 
					
						
							|  |  |  |     auto srcDes = TensorUtils::getDescribe(src); | 
					
						
							|  |  |  |     auto dstDes = TensorUtils::getDescribe(dst); | 
					
						
							|  |  |  |     dstDes->dimensionFormat = srcDes->dimensionFormat; | 
					
						
							|  |  |  |     dstDes->tensorArrayAttr.reset(new TensorArrayAttr); | 
					
						
							|  |  |  |     dstDes->tensorArrayAttr->isDynamicSize = srcDes->tensorArrayAttr->isDynamicSize; | 
					
						
							|  |  |  |     dstDes->tensorArrayAttr->isIdenticalShape = srcDes->tensorArrayAttr->isIdenticalShape; | 
					
						
							|  |  |  |     dstDes->tensorArrayAttr->arraySize = srcDes->tensorArrayAttr->arraySize; | 
					
						
							|  |  |  |     dstDes->tensorArrayAttr->elemShape = srcDes->tensorArrayAttr->elemShape; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static void updateTensorArrayDims(Tensor* t) { | 
					
						
							|  |  |  |     auto des = TensorUtils::getDescribe(t); | 
					
						
							|  |  |  |     // shape : [Sum(elemShape)]
 | 
					
						
							|  |  |  |     t->buffer().dimensions = 1; | 
					
						
							|  |  |  |     int totalSize = 0; | 
					
						
							|  |  |  |     for (auto elem : des->tensorArrayAttr->elemShape) { | 
					
						
							|  |  |  |         int elemSize = 1; | 
					
						
							|  |  |  |         for (auto dim : elem) { | 
					
						
							|  |  |  |             elemSize *= dim; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         totalSize += elemSize; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     t->setLength(0, des->tensorArrayAttr->arraySize * totalSize); | 
					
						
							|  |  |  |     t->setLength(1, 1); | 
					
						
							|  |  |  |     t->setLength(2, 1); | 
					
						
							|  |  |  |     t->setLength(3, 1); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArray ============================
 | 
					
						
							|  |  |  | class TensorArrayComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : size
 | 
					
						
							|  |  |  |     // outputs: handle, flow_out
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(1 == inputs.size() && 2 == outputs.size()); | 
					
						
							|  |  |  |         auto param = op->main_as_TensorArray(); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         for (int i = 0; i < 2; i++) { | 
					
						
							|  |  |  |             auto& output = outputs[i]; | 
					
						
							|  |  |  |             auto des = TensorUtils::getDescribe(output); | 
					
						
							|  |  |  |             // 1. set TensorArray attrs
 | 
					
						
							|  |  |  |             des->tensorArrayAttr.reset(new TensorArrayAttr); | 
					
						
							|  |  |  |             des->tensorArrayAttr->isDynamicSize = param->dynamic_size(); | 
					
						
							|  |  |  |             des->tensorArrayAttr->isIdenticalShape = param->identical_element_shapes(); | 
					
						
							|  |  |  |             if (param->element_shape() && param->element_shape()->size() > 0) { | 
					
						
							|  |  |  |                 std::vector<int> elemShape(param->element_shape()->size()); | 
					
						
							|  |  |  |                 for (int i = 0; i < param->element_shape()->size(); i++) { | 
					
						
							|  |  |  |                     elemShape[i] = param->element_shape()->Get(i); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 des->tensorArrayAttr->elemShape.emplace_back(std::move(elemShape)); | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |             des->tensorArrayAttr->arraySize = inputs[0]->host<uint32_t>()[0]; | 
					
						
							|  |  |  |             // 2. set dtype, dimension format and dims
 | 
					
						
							|  |  |  |             output->setType(param->T()); | 
					
						
							|  |  |  |             TensorUtils::getDescribe(output)->dimensionFormat = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |             updateTensorArrayDims(output); | 
					
						
							|  |  |  |             MNN_ASSERT(des->tensorArrayAttr != nullptr); | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArrayComputer, OpType_TensorArray, {0}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArraySize ============================
 | 
					
						
							|  |  |  | class TensorArraySizeComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, flow_in
 | 
					
						
							|  |  |  |     // outputs: tensor
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(2 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         MNN_ASSERT(TensorUtils::getDescribe(inputs[1])->tensorArrayAttr != nullptr); | 
					
						
							|  |  |  |         outputs[0]->setType(DataType_DT_INT32); | 
					
						
							|  |  |  |         outputs[0]->buffer().dimensions    = 1; | 
					
						
							|  |  |  |         outputs[0]->setLength(0, 1); | 
					
						
							|  |  |  |         TensorUtils::getDescribe(outputs[0])->dimensionFormat = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE(TensorArraySizeComputer, OpType_TensorArraySize); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArrayRead ============================
 | 
					
						
							|  |  |  | class TensorArrayReadComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, index, flow_in
 | 
					
						
							|  |  |  |     // outputs: tensor
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(3 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto des = TensorUtils::getDescribe(inputs[2]); | 
					
						
							|  |  |  |         if (des->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::vector<int> readElemShape; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         int readIndex = inputs[1]->host<uint32_t>()[0]; | 
					
						
							|  |  |  |         if (!des->tensorArrayAttr->isIdenticalShape && des->tensorArrayAttr->elemShape.size() > readIndex) { | 
					
						
							|  |  |  |             readElemShape = des->tensorArrayAttr->elemShape[readIndex]; | 
					
						
							|  |  |  |         } else if (des->tensorArrayAttr->elemShape.size() >= 1) { | 
					
						
							|  |  |  |             readElemShape = des->tensorArrayAttr->elemShape[0]; | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             MNN_ASSERT(false); | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         outputs[0]->setType(op->main_as_TensorArray()->T()); | 
					
						
							|  |  |  |         outputs[0]->buffer().dimensions    = readElemShape.size(); | 
					
						
							|  |  |  |         for (int i = 0; i < readElemShape.size(); i++) { | 
					
						
							|  |  |  |             outputs[0]->setLength(i, readElemShape[i]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         TensorUtils::getDescribe(outputs[0])->dimensionFormat = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArrayReadComputer, OpType_TensorArrayRead, {1}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArrayWrite ============================
 | 
					
						
							|  |  |  | class TensorArrayWriteComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, index, value, flow_in
 | 
					
						
							|  |  |  |     // outputs: flow_out
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(4 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto inDes  = TensorUtils::getDescribe(inputs[3]); | 
					
						
							|  |  |  |         auto outDes = TensorUtils::getDescribe(outputs[0]); | 
					
						
							|  |  |  |         if (inDes->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         copyTensorArrayAttribute(inputs[3], outputs[0]); | 
					
						
							|  |  |  |         outputs[0]->setType(op->main_as_TensorArray()->T()); | 
					
						
							|  |  |  |         int writeIndex = inputs[1]->host<uint32_t>()[0]; | 
					
						
							|  |  |  |         // update arraySize
 | 
					
						
							|  |  |  |         if (!inDes->tensorArrayAttr->isDynamicSize) { | 
					
						
							|  |  |  |             MNN_ASSERT(writeIndex < inDes->tensorArrayAttr->arraySize); | 
					
						
							|  |  |  |         } else if (writeIndex >= inDes->tensorArrayAttr->arraySize) { | 
					
						
							|  |  |  |             outDes->tensorArrayAttr->arraySize = writeIndex + 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // update elemShape
 | 
					
						
							|  |  |  |         auto writeShape = inputs[2]->shape(); | 
					
						
							|  |  |  |         if (outDes->tensorArrayAttr->isIdenticalShape) { | 
					
						
							|  |  |  |             if (outDes->tensorArrayAttr->elemShape.empty()) { | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->elemShape.push_back(writeShape); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |             } else { | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->elemShape[0] = writeShape; | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             for (int i = outDes->tensorArrayAttr->elemShape.size(); i <= writeIndex; i++) { | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->elemShape.push_back(writeShape); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             outDes->tensorArrayAttr->elemShape[writeIndex] = writeShape; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         updateTensorArrayDims(outputs[0]); | 
					
						
							|  |  |  |         MNN_ASSERT(outDes->tensorArrayAttr != nullptr); | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArrayWriteComputer, OpType_TensorArrayWrite, {1}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArrayGather ============================
 | 
					
						
							|  |  |  | class TensorArrayGatherComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, indices, flow_in
 | 
					
						
							|  |  |  |     // outputs: tensor
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(3 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto inDes  = TensorUtils::getDescribe(inputs[2]); | 
					
						
							|  |  |  |         auto outDes = TensorUtils::getDescribe(outputs[0]); | 
					
						
							|  |  |  |         if (inDes->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto param = op->main_as_TensorArray(); | 
					
						
							|  |  |  |         outputs[0]->setType(param->T()); | 
					
						
							|  |  |  |         outDes->dimensionFormat = inDes->dimensionFormat; | 
					
						
							|  |  |  |         outputs[0]->buffer().dimensions = inputs[2]->buffer().dimensions; | 
					
						
							|  |  |  |         outputs[0]->setLength(0, inputs[1]->length(0)); | 
					
						
							|  |  |  |         // using param shape
 | 
					
						
							|  |  |  |         if (param->element_shape() && param->element_shape()->size() > 0) { | 
					
						
							|  |  |  |             outputs[0]->buffer().dimensions = param->element_shape()->size() + 1; | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |             MNN_ASSERT(param->element_shape()->size() == inDes->tensorArrayAttr->elemShape[0].size()); | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             for (int i = 0; i < param->element_shape()->size(); i++) { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |                 int dimValue = param->element_shape()->Get(i); | 
					
						
							|  |  |  |                 if (dimValue < 0) { | 
					
						
							|  |  |  |                     dimValue = inDes->tensorArrayAttr->elemShape[0][i]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 outputs[0]->setLength(1 + i, dimValue); | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             if (inDes->tensorArrayAttr->elemShape.size() == 1) { | 
					
						
							|  |  |  |                 for (int i = 0; i < inDes->tensorArrayAttr->elemShape[0].size(); i++) { | 
					
						
							|  |  |  |                     outputs[0]->setLength(1 + i, inDes->tensorArrayAttr->elemShape[0][i]); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 MNN_ASSERT(false); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArrayGatherComputer, OpType_TensorArrayGather, {1}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArrayScatter ============================
 | 
					
						
							|  |  |  | class TensorArrayScatterComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, indices, value, flow_in
 | 
					
						
							|  |  |  |     // outputs: flow_out
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(4 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto inDes  = TensorUtils::getDescribe(inputs[3]); | 
					
						
							|  |  |  |         auto outDes = TensorUtils::getDescribe(outputs[0]); | 
					
						
							|  |  |  |         if (inDes->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         copyTensorArrayAttribute(inputs[3], outputs[0]); | 
					
						
							|  |  |  |         for (int i = 0; i < inputs[1]->length(0); i++) { | 
					
						
							|  |  |  |             int writeIndex = inputs[1]->host<uint32_t>()[i]; | 
					
						
							|  |  |  |             if (!inDes->tensorArrayAttr->isDynamicSize) { | 
					
						
							|  |  |  |                 MNN_ASSERT(writeIndex < inDes->tensorArrayAttr->arraySize); | 
					
						
							|  |  |  |             } else if (writeIndex >= inDes->tensorArrayAttr->arraySize) { | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->arraySize = writeIndex + 1; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             std::vector<int> writeElemShape(inputs[2]->shape()); | 
					
						
							|  |  |  |             writeElemShape.erase(writeElemShape.begin()); | 
					
						
							|  |  |  |             if (outDes->tensorArrayAttr->elemShape.empty()) { | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->elemShape.emplace_back(std::move(writeElemShape)); | 
					
						
							|  |  |  |             } else { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |                 outDes->tensorArrayAttr->elemShape[0] = writeElemShape; | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         outputs[0]->setType(op->main_as_TensorArray()->T()); | 
					
						
							|  |  |  |         updateTensorArrayDims(outputs[0]); | 
					
						
							|  |  |  |         MNN_ASSERT(outDes->tensorArrayAttr != nullptr); | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArrayScatterComputer, OpType_TensorArrayScatter, {1}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArraySplit ============================
 | 
					
						
							|  |  |  | class TensorArraySplitComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, value, lengths, flow_in
 | 
					
						
							|  |  |  |     // outputs: flow_out
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(4 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto inDes = TensorUtils::getDescribe(inputs[3]); | 
					
						
							|  |  |  |         if (inDes->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         copyTensorArrayAttribute(inputs[3], outputs[0]); | 
					
						
							|  |  |  |         outputs[0]->setType(op->main_as_TensorArray()->T()); | 
					
						
							|  |  |  |         auto outDes = TensorUtils::getDescribe(outputs[0]); | 
					
						
							|  |  |  |         if (outDes->tensorArrayAttr->isIdenticalShape) { | 
					
						
							|  |  |  |             std::vector<int> writeElemShape(inputs[1]->shape()); | 
					
						
							|  |  |  |             outDes->tensorArrayAttr->arraySize = writeElemShape[0]; | 
					
						
							|  |  |  |             writeElemShape.erase(writeElemShape.begin()); | 
					
						
							|  |  |  |             outDes->tensorArrayAttr->elemShape.emplace_back(std::move(writeElemShape)); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             auto value = inputs[1]; | 
					
						
							|  |  |  |             auto lengths = inputs[2]; | 
					
						
							|  |  |  |             outDes->tensorArrayAttr->arraySize = lengths->length(0); | 
					
						
							|  |  |  |             std::vector<int> vShape(value->shape()); | 
					
						
							|  |  |  |             const int* lengthPtr = lengths->host<int>(); | 
					
						
							|  |  |  |             for (int i = 0; i < lengths->length(0); i++) { | 
					
						
							|  |  |  |                 auto elemShape = vShape; | 
					
						
							|  |  |  |                 elemShape[0] = lengthPtr[i]; | 
					
						
							|  |  |  |                 outDes->tensorArrayAttr->elemShape.emplace_back(std::move(elemShape)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         updateTensorArrayDims(outputs[0]); | 
					
						
							|  |  |  |         MNN_ASSERT(outDes->tensorArrayAttr != nullptr); | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE_INPUTS(TensorArraySplitComputer, OpType_TensorArraySplit, {2}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ============================ TensorArrayConcat ============================
 | 
					
						
							|  |  |  | class TensorArrayConcatComputer : public SizeComputer { | 
					
						
							|  |  |  |     // inputs : handle, flow_in
 | 
					
						
							|  |  |  |     // outputs: tensor
 | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(2 == inputs.size() && 1 == outputs.size()); | 
					
						
							|  |  |  |         auto inDes  = TensorUtils::getDescribe(inputs[1]); | 
					
						
							|  |  |  |         if (inDes->tensorArrayAttr == nullptr) { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         outputs[0]->setType(op->main_as_TensorArray()->T()); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         if (inDes->tensorArrayAttr->elemShape.size() >= 1) { | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  |             outputs[0]->buffer().dimensions = inDes->tensorArrayAttr->elemShape[0].size() + 1; | 
					
						
							|  |  |  |             outputs[0]->setLength(0, inDes->tensorArrayAttr->arraySize); | 
					
						
							|  |  |  |             for (int i = 0; i < inDes->tensorArrayAttr->elemShape[0].size(); i++) { | 
					
						
							|  |  |  |                 outputs[0]->setLength(1 + i, inDes->tensorArrayAttr->elemShape[0][i]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | REGISTER_SHAPE(TensorArrayConcatComputer, OpType_TensorArrayConcat); | 
					
						
							|  |  |  | } // namespace MNN
 |