//
// ShapeTensorArray.cpp
// MNN
//
// Created by MNN on 2020/12/21.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <numeric>
#include "shape/SizeComputer.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp" // TensorUtils::getDescribe is used throughout this file
#include "math.h"

namespace MNN {
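// Copy the TensorArray description (dimension format, dynamic-size flag,
// identical-shape flag, array size and element shapes) from src to dst.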
static void copyTensorArrayAttribute(const Tensor* src, Tensor* dst) {
    auto srcDes = TensorUtils::getDescribe(src);
    auto dstDes = TensorUtils::getDescribe(dst);
    dstDes->dimensionFormat = srcDes->dimensionFormat;
    dstDes->tensorArrayAttr.reset(new TensorArrayAttr);
    dstDes->tensorArrayAttr->isDynamicSize = srcDes->tensorArrayAttr->isDynamicSize;
    dstDes->tensorArrayAttr->isIdenticalShape = srcDes->tensorArrayAttr->isIdenticalShape;
    dstDes->tensorArrayAttr->arraySize = srcDes->tensorArrayAttr->arraySize;
    dstDes->tensorArrayAttr->elemShape = srcDes->tensorArrayAttr->elemShape;
}
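
// A TensorArray is held in one flattened 1-D buffer; its length is the sum of the
// element sizes (or elemSize * arraySize when a single shape describes every element).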
static void updateTensorArrayDims(Tensor* t) {
    auto des = TensorUtils::getDescribe(t);
    // shape : [Sum(elemShape)]
    t->buffer().dimensions = 1;
    int totalSize = 0, arraySize = des->tensorArrayAttr->arraySize;
    for (auto elem : des->tensorArrayAttr->elemShape) {
        int elemSize = 1;
        for (auto dim : elem) {
            elemSize *= dim;
        }
        totalSize += elemSize;
    }
    if (des->tensorArrayAttr->elemShape.size() == 1 && arraySize > 1) {
        totalSize *= arraySize;
    } else if (totalSize == 0) {
        totalSize = 1; // bypass MNNV3 Dynamic Graph Executor zeroShape check
    }
    t->setLength(0, totalSize);
    t->setLength(1, 1);
    t->setLength(2, 1);
    t->setLength(3, 1);
}

// ============================ TensorArray ============================
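// Create a TensorArray: the array size comes from inputs[0], the element shape and the
// dynamic-size / identical-shape flags come from the op parameter; both outputs (handle
// and flow_out) carry the same description.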
class TensorArrayComputer : public SizeComputer {
    // inputs : size
    // outputs: handle, flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(1 == inputs.size() && 2 == outputs.size());
        auto param = op->main_as_TensorArray();
        for (int i = 0; i < 2; i++) {
            auto& output = outputs[i];
            auto des = TensorUtils::getDescribe(output);
            // 1. set TensorArray attrs
            des->tensorArrayAttr.reset(new TensorArrayAttr);
            des->tensorArrayAttr->isDynamicSize = param->dynamic_size();
            des->tensorArrayAttr->isIdenticalShape = param->identical_element_shapes();
            if (param->element_shape() && param->element_shape()->size() > 0) {
                std::vector<int> elemShape(param->element_shape()->size());
                for (int i = 0; i < param->element_shape()->size(); i++) {
                    elemShape[i] = param->element_shape()->Get(i);
                    if (elemShape[i] < 0) {
                        elemShape[i] = 0;
                    }
                }
                des->tensorArrayAttr->elemShape.emplace_back(std::move(elemShape));
            }
            des->tensorArrayAttr->arraySize = inputs[0]->host<uint32_t>()[0];
            // 2. set dtype, dimension format and dims
            output->setType(param->T());
            TensorUtils::getDescribe(output)->dimensionFormat = op->defaultDimentionFormat();
            updateTensorArrayDims(output);
            MNN_ASSERT(des->tensorArrayAttr != nullptr);
        }
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayComputer, OpType_TensorArray, {0});

// ============================ TensorArraySize ============================
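// The output is a single int32 holding the array size; only its shape ([1]) is fixed here.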
class TensorArraySizeComputer : public SizeComputer {
    // inputs : handle, flow_in
    // outputs: tensor
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(2 == inputs.size() && 1 == outputs.size());
        MNN_ASSERT(TensorUtils::getDescribe(inputs[1])->tensorArrayAttr != nullptr);
        outputs[0]->setType(DataType_DT_INT32);
        outputs[0]->buffer().dimensions = 1;
        outputs[0]->setLength(0, 1);
        TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[1])->dimensionFormat;
        return true;
    }
};
REGISTER_SHAPE(TensorArraySizeComputer, OpType_TensorArraySize);

// ============================ TensorArrayRead ============================
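// Read one element: the output takes the shape recorded for the requested index,
// or the shared shape when all elements are identical.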
class TensorArrayReadComputer : public SizeComputer {
    // inputs : handle, index, flow_in
    // outputs: tensor
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(3 == inputs.size() && 1 == outputs.size());
        auto des = TensorUtils::getDescribe(inputs[2]);
        if (des->tensorArrayAttr == nullptr) {
            return false;
        }
        std::vector<int> readElemShape;
        int readIndex = inputs[1]->host<uint32_t>()[0];
        if (!des->tensorArrayAttr->isIdenticalShape && des->tensorArrayAttr->elemShape.size() > readIndex) {
            readElemShape = des->tensorArrayAttr->elemShape[readIndex];
        } else if (des->tensorArrayAttr->elemShape.size() >= 1) {
            readElemShape = des->tensorArrayAttr->elemShape[0];
        } else {
            MNN_ASSERT(false);
        }
        outputs[0]->buffer().type = inputs[2]->buffer().type;
        outputs[0]->buffer().dimensions = readElemShape.size();
        for (int i = 0; i < readElemShape.size(); i++) {
            outputs[0]->setLength(i, readElemShape[i]);
        }
        TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[2])->dimensionFormat;
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayReadComputer, OpType_TensorArrayRead, {1});

// ============================ TensorArrayWrite ============================
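// Write value at index: copy the array attributes to flow_out, grow arraySize when a
// dynamic array is written past its end, and record the written element shape.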
class TensorArrayWriteComputer : public SizeComputer {
    // inputs : handle, index, value, flow_in
    // outputs: flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(4 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[3]);
        auto outDes = TensorUtils::getDescribe(outputs[0]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        if (TensorUtils::getDescribe(inputs[2])->dimensionFormat != inDes->dimensionFormat) {
            MNN_ASSERT(false);
            return false;
        }
        copyTensorArrayAttribute(inputs[3], outputs[0]);
        outputs[0]->buffer().type = inputs[2]->buffer().type;
        int writeIndex = inputs[1]->host<uint32_t>()[0];
        // update arraySize
        if (!inDes->tensorArrayAttr->isDynamicSize) {
            MNN_ASSERT(writeIndex < inDes->tensorArrayAttr->arraySize);
        } else if (writeIndex >= inDes->tensorArrayAttr->arraySize) {
            outDes->tensorArrayAttr->arraySize = writeIndex + 1;
        }
        // update elemShape
        auto writeShape = inputs[2]->shape();
        if (outDes->tensorArrayAttr->isIdenticalShape) {
            if (outDes->tensorArrayAttr->elemShape.empty()) {
                outDes->tensorArrayAttr->elemShape.push_back(writeShape);
            } else {
                outDes->tensorArrayAttr->elemShape[0] = writeShape;
            }
        } else {
            for (int i = outDes->tensorArrayAttr->elemShape.size(); i <= writeIndex; i++) {
                outDes->tensorArrayAttr->elemShape.push_back(writeShape);
            }
            outDes->tensorArrayAttr->elemShape[writeIndex] = writeShape;
        }
        updateTensorArrayDims(outputs[0]);
        MNN_ASSERT(outDes->tensorArrayAttr != nullptr);
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayWriteComputer, OpType_TensorArrayWrite, {1});

// ============================ TensorArrayGather ============================
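// Gather the indexed elements into one tensor of shape [numIndices, elemShape...];
// the element shape is taken from the op parameter when present, otherwise from the array.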
class TensorArrayGatherComputer : public SizeComputer {
    // inputs : handle, indices, flow_in
    // outputs: tensor
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(3 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[2]);
        auto outDes = TensorUtils::getDescribe(outputs[0]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        auto param = op->main_as_TensorArray();
        outputs[0]->setType(param->T());
        outDes->dimensionFormat = inDes->dimensionFormat;
        outputs[0]->buffer().dimensions = inputs[2]->buffer().dimensions;
        outputs[0]->setLength(0, inputs[1]->length(0));
        // using param shape
        if (param->element_shape() && param->element_shape()->size() > 0) {
            outputs[0]->buffer().dimensions = param->element_shape()->size() + 1;
            MNN_ASSERT(param->element_shape()->size() == inDes->tensorArrayAttr->elemShape[0].size());
            for (int i = 0; i < param->element_shape()->size(); i++) {
                int dimValue = param->element_shape()->Get(i);
                if (dimValue < 0) {
                    dimValue = inDes->tensorArrayAttr->elemShape[0][i];
                }
                outputs[0]->setLength(1 + i, dimValue);
            }
        } else {
            if (inDes->tensorArrayAttr->elemShape.size() == 1) {
                for (int i = 0; i < inDes->tensorArrayAttr->elemShape[0].size(); i++) {
                    outputs[0]->setLength(1 + i, inDes->tensorArrayAttr->elemShape[0][i]);
                }
            } else {
                MNN_ASSERT(false);
            }
        }
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayGatherComputer, OpType_TensorArrayGather, {1});

// ============================ TensorArrayScatter ============================
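// Scatter slices of value (taken along its first axis) to the given indices,
// growing a dynamic array as needed and recording the per-element shape.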
class TensorArrayScatterComputer : public SizeComputer {
    // inputs : handle, indices, value, flow_in
    // outputs: flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(4 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[3]);
        auto outDes = TensorUtils::getDescribe(outputs[0]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        if (TensorUtils::getDescribe(inputs[2])->dimensionFormat != inDes->dimensionFormat) {
            MNN_ASSERT(false);
            return false;
        }
        copyTensorArrayAttribute(inputs[3], outputs[0]);
        for (int i = 0; i < inputs[1]->length(0); i++) {
            int writeIndex = inputs[1]->host<uint32_t>()[i];
            if (!inDes->tensorArrayAttr->isDynamicSize) {
                MNN_ASSERT(writeIndex < inDes->tensorArrayAttr->arraySize);
            } else if (writeIndex >= inDes->tensorArrayAttr->arraySize) {
                outDes->tensorArrayAttr->arraySize = writeIndex + 1;
            }
            std::vector<int> writeElemShape(inputs[2]->shape());
            writeElemShape.erase(writeElemShape.begin());
            if (outDes->tensorArrayAttr->elemShape.empty()) {
                outDes->tensorArrayAttr->elemShape.emplace_back(std::move(writeElemShape));
            } else {
                outDes->tensorArrayAttr->elemShape[0] = writeElemShape;
            }
        }
        outputs[0]->buffer().type = inputs[3]->buffer().type;
        updateTensorArrayDims(outputs[0]);
        MNN_ASSERT(outDes->tensorArrayAttr != nullptr);
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayScatterComputer, OpType_TensorArrayScatter, {1});

// ============================ TensorArraySplit ============================
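// Split value along `axis` into array elements: equally sized chunks when lengths is a
// scalar, otherwise one element per entry of lengths; keepdims controls whether the
// split axis is kept in each element.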
class TensorArraySplitComputer : public SizeComputer {
    // inputs : handle, value, lengths, flow_in
    // outputs: flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(4 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[3]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        auto taParam = op->main_as_TensorArray();
        int splitAxis = (taParam->axis() + inputs[1]->dimensions()) % inputs[1]->dimensions();
        int keepdims = taParam->keepdims();
        copyTensorArrayAttribute(inputs[3], outputs[0]);
        outputs[0]->setType(op->main_as_TensorArray()->T());
        auto outDes = TensorUtils::getDescribe(outputs[0]);
        if (outDes->tensorArrayAttr->isIdenticalShape) {
            std::vector<int> writeElemShape(inputs[1]->shape());
            outDes->tensorArrayAttr->arraySize = writeElemShape[splitAxis];
            if (keepdims) {
                writeElemShape[splitAxis] = 1;
            } else {
                writeElemShape.erase(writeElemShape.begin() + splitAxis);
            }
            outDes->tensorArrayAttr->elemShape.emplace_back(std::move(writeElemShape));
        } else {
            auto value   = inputs[1];
            auto lengths = inputs[2];
            bool scalarSplit = (lengths->elementSize() == 1);
            std::vector<int> vShape(value->shape());
            int totalLen = value->shape()[splitAxis], splitNum;
            if (scalarSplit) {
                splitNum = UP_DIV(totalLen, lengths->host<int>()[0]);
                MNN_ASSERT(keepdims || lengths->host<int>()[0] == 1);
            } else {
                splitNum = lengths->length(0);
                MNN_ASSERT(std::accumulate(lengths->host<int>(), lengths->host<int>() + splitNum, 0) == totalLen);
            }
            outDes->tensorArrayAttr->arraySize = splitNum;
            for (int i = 0; i < splitNum; ++i) {
                auto elemShape = vShape;
                if (scalarSplit) {
                    if (!keepdims) {
                        elemShape.erase(elemShape.begin() + splitAxis);
                    } else {
                        int splitLen = lengths->host<int>()[0];
                        elemShape[splitAxis] = ALIMIN(splitLen, totalLen - i * splitLen);
                    }
                } else {
                    elemShape[splitAxis] = lengths->host<int>()[i];
                }
                outDes->tensorArrayAttr->elemShape.emplace_back(std::move(elemShape));
            }
        }
        updateTensorArrayDims(outputs[0]);
        MNN_ASSERT(outDes->tensorArrayAttr != nullptr);
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArraySplitComputer, OpType_TensorArraySplit, {2});

// ============================ TensorArrayConcat ============================
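// Concatenate (new_axis=false) or stack (new_axis=true) all elements along `axis`;
// element shapes must match everywhere except, when concatenating, the concat axis.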
class TensorArrayConcatComputer : public SizeComputer {
    // inputs : handle, flow_in
    // outputs: tensor
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(2 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[1]);
        if (inDes->tensorArrayAttr == nullptr || inDes->tensorArrayAttr->arraySize == 0) {
            MNN_ASSERT(false);
            return false;
        }
        copyTensorArrayAttribute(inputs[1], outputs[0]);
        auto tpParam = op->main_as_TensorArray();
        int concatAxis = tpParam->axis(), newAxis = tpParam->new_axis();
        outputs[0]->buffer().type = inputs[1]->buffer().type;
        const auto& elemShapes = inDes->tensorArrayAttr->elemShape;
        auto outShape = elemShapes[0];
        bool valid = true; // avoid MNN_ASSERT because it's a no-op in release mode
        for (int i = 1; valid && (i < elemShapes.size()); ++i) {
            auto elemShape = elemShapes[inDes->tensorArrayAttr->isIdenticalShape ? 0 : i];
            valid &= (outShape.size() == elemShape.size());
            if (newAxis) {
                valid &= (std::equal(outShape.begin(), outShape.end(), elemShape.begin()));
            } else {
                valid &= (std::equal(outShape.begin(), outShape.begin() + concatAxis, elemShape.begin()));
                valid &= (std::equal(outShape.begin() + concatAxis + 1, outShape.end(), elemShape.begin() + concatAxis + 1));
                outShape[concatAxis] += elemShape[concatAxis];
            }
        }
        if (!valid) {
            MNN_ERROR("Invalid input, elements in seq have different shape [new_axis=true need same shape, new_axis=false need same shape except concat_axis dim]\n");
            return false;
        }
        if (newAxis) {
            outShape.insert(outShape.begin() + concatAxis, inDes->tensorArrayAttr->arraySize);
        }
        outputs[0]->buffer().dimensions = outShape.size();
        for (int i = 0; i < outShape.size(); ++i) {
            outputs[0]->setLength(i, outShape[i]);
        }
        return true;
    }
};
REGISTER_SHAPE(TensorArrayConcatComputer, OpType_TensorArrayConcat);

// ============================ TensorArrayInsert ============================
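// Insert value at position (negative positions count from the end) into a dynamic
// array: arraySize grows by one and the new element shape is recorded.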
class TensorArrayInsertComputer : public SizeComputer {
    // inputs : handle, position, value, flow_in
    // outputs: flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(4 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[3]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        if (TensorUtils::getDescribe(inputs[2])->dimensionFormat != inDes->dimensionFormat) {
            MNN_ASSERT(false);
            return false;
        }
        MNN_ASSERT(inDes->tensorArrayAttr->isDynamicSize);

        copyTensorArrayAttribute(inputs[3], outputs[0]);
        auto outSeq = TensorUtils::getDescribe(outputs[0])->tensorArrayAttr;
        outputs[0]->buffer().type = inputs[3]->buffer().type;
        int inSeqSize = inDes->tensorArrayAttr->arraySize, insertIndex = inputs[1]->host<int32_t>()[0];
        MNN_ASSERT(insertIndex >= -inSeqSize && insertIndex <= inSeqSize); // [-n, n]
        insertIndex += (insertIndex < 0 ? inSeqSize : 0);
        // update arraySize
        outSeq->arraySize += 1;
        // update elemShape
        auto insertShape = inputs[2]->shape();
        auto& outSeqShapes = outSeq->elemShape;
        if (outSeq->isIdenticalShape && !outSeqShapes.empty()) {
            MNN_ASSERT(std::equal(insertShape.begin(), insertShape.end(), outSeqShapes[0].begin()));
        } else {
            outSeqShapes.insert(outSeqShapes.begin() + insertIndex, insertShape);
        }
        updateTensorArrayDims(outputs[0]);
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayInsertComputer, OpType_TensorArrayInsert, {1});

// ============================ TensorArrayErase ============================
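// Erase the element at position (negative positions count from the end) from a
// dynamic array: arraySize shrinks by one and the recorded shape is dropped.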
class TensorArrayEraseComputer : public SizeComputer {
    // inputs : handle, position, flow_in
    // outputs: flow_out
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        MNN_ASSERT(3 == inputs.size() && 1 == outputs.size());
        auto inDes = TensorUtils::getDescribe(inputs[2]);
        if (inDes->tensorArrayAttr == nullptr) {
            MNN_ASSERT(false);
            return false;
        }
        MNN_ASSERT(inDes->tensorArrayAttr->isDynamicSize);

        copyTensorArrayAttribute(inputs[2], outputs[0]);
        auto outSeq = TensorUtils::getDescribe(outputs[0])->tensorArrayAttr;
        outputs[0]->buffer().type = inputs[2]->buffer().type;
        int inSeqSize = outSeq->arraySize, eraseIndex = inputs[1]->host<int32_t>()[0];
        MNN_ASSERT(eraseIndex >= -inSeqSize && eraseIndex < inSeqSize); // [-n, n-1]
        eraseIndex += (eraseIndex < 0 ? inSeqSize : 0);
        // update arraySize
        outSeq->arraySize -= 1;
        // update elemShape
        if (!outSeq->isIdenticalShape) {
            outSeq->elemShape.erase(outSeq->elemShape.begin() + eraseIndex);
        }
        updateTensorArrayDims(outputs[0]);
        return true;
    }
};
REGISTER_SHAPE_INPUTS(TensorArrayEraseComputer, OpType_TensorArrayErase, {1});

} // namespace MNN