2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								//
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//  ShapeTensorArray.cpp
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//  MNN
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//  Created by MNN on 2020/12/21.
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//  Copyright © 2018, Alibaba Group Holding Limited
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# include  <numeric> 
  
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								# include  "shape/SizeComputer.hpp" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "core/Macro.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "math.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								namespace  MNN  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								static  void  copyTensorArrayAttribute ( const  Tensor *  src ,  Tensor *  dst )  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    auto  srcDes  =  TensorUtils : : getDescribe ( src ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    auto  dstDes  =  TensorUtils : : getDescribe ( dst ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > dimensionFormat  =  srcDes - > dimensionFormat ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > tensorArrayAttr . reset ( new  TensorArrayAttr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > tensorArrayAttr - > isDynamicSize  =  srcDes - > tensorArrayAttr - > isDynamicSize ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > tensorArrayAttr - > isIdenticalShape  =  srcDes - > tensorArrayAttr - > isIdenticalShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > tensorArrayAttr - > arraySize  =  srcDes - > tensorArrayAttr - > arraySize ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    dstDes - > tensorArrayAttr - > elemShape  =  srcDes - > tensorArrayAttr - > elemShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								static  void  updateTensorArrayDims ( Tensor *  t )  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    auto  des  =  TensorUtils : : getDescribe ( t ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // shape : [Sum(elemShape)]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    t - > buffer ( ) . dimensions  =  1 ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    int  totalSize  =  0 ,  arraySize  =  des - > tensorArrayAttr - > arraySize ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    for  ( auto  elem  :  des - > tensorArrayAttr - > elemShape )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  elemSize  =  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  ( auto  dim  :  elem )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            elemSize  * =  dim ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        totalSize  + =  elemSize ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  ( des - > tensorArrayAttr - > elemShape . size ( )  = =  1  & &  arraySize  >  1 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        totalSize  * =  arraySize ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }  else  if  ( totalSize  = =  0 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        totalSize  =  1 ;  // bypass MNNV3 Dynamic Graph Executor zeroShape check
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    t - > setLength ( 0 ,  totalSize ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    t - > setLength ( 1 ,  1 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    t - > setLength ( 2 ,  1 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    t - > setLength ( 3 ,  1 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArray ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : size
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: handle, flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 1  = =  inputs . size ( )  & &  2  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  param  =  op - > main_as_TensorArray ( ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        for  ( int  i  =  0 ;  i  <  2 ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            auto &  output  =  outputs [ i ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            auto  des  =  TensorUtils : : getDescribe ( output ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            // 1. set TensorArray attrs
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            des - > tensorArrayAttr . reset ( new  TensorArrayAttr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            des - > tensorArrayAttr - > isDynamicSize  =  param - > dynamic_size ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            des - > tensorArrayAttr - > isIdenticalShape  =  param - > identical_element_shapes ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( param - > element_shape ( )  & &  param - > element_shape ( ) - > size ( )  >  0 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                std : : vector < int >  elemShape ( param - > element_shape ( ) - > size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                for  ( int  i  =  0 ;  i  <  param - > element_shape ( ) - > size ( ) ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    elemShape [ i ]  =  param - > element_shape ( ) - > Get ( i ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-10-14 14:59:14 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    if  ( elemShape [ i ]  <  0 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        elemShape [ i ]  =  0 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    } 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                des - > tensorArrayAttr - > elemShape . emplace_back ( std : : move ( elemShape ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            des - > tensorArrayAttr - > arraySize  =  inputs [ 0 ] - > host < uint32_t > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            // 2. set dtype, dimension format and dims
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            output - > setType ( param - > T ( ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            TensorUtils : : getDescribe ( output ) - > dimensionFormat  =  op - > defaultDimentionFormat ( ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            updateTensorArrayDims ( output ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( des - > tensorArrayAttr  ! =  nullptr ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayComputer ,  OpType_TensorArray ,  { 0 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArraySize ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArraySizeComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: tensor
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 2  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( TensorUtils : : getDescribe ( inputs [ 1 ] ) - > tensorArrayAttr  ! =  nullptr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( DataType_DT_INT32 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . dimensions     =  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setLength ( 0 ,  1 ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        TensorUtils : : getDescribe ( outputs [ 0 ] ) - > dimensionFormat  =  TensorUtils : : getDescribe ( inputs [ 1 ] ) - > dimensionFormat ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE ( TensorArraySizeComputer ,  OpType_TensorArraySize ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayRead ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayReadComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, index, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: tensor
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 3  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  des  =  TensorUtils : : getDescribe ( inputs [ 2 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( des - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        std : : vector < int >  readElemShape ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-04-08 15:34:23 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        int  readIndex  =  inputs [ 1 ] - > host < uint32_t > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( ! des - > tensorArrayAttr - > isIdenticalShape  & &  des - > tensorArrayAttr - > elemShape . size ( )  >  readIndex )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            readElemShape  =  des - > tensorArrayAttr - > elemShape [ readIndex ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }  else  if  ( des - > tensorArrayAttr - > elemShape . size ( )  > =  1 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            readElemShape  =  des - > tensorArrayAttr - > elemShape [ 0 ] ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        }  else  { 
							 
						 
					
						
							
								
									
										
										
										
											2021-04-08 15:34:23 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( op - > main_as_TensorArray ( ) - > T ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . dimensions     =  readElemShape . size ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  ( int  i  =  0 ;  i  <  readElemShape . size ( ) ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outputs [ 0 ] - > setLength ( i ,  readElemShape [ i ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        TensorUtils : : getDescribe ( outputs [ 0 ] ) - > dimensionFormat  =  TensorUtils : : getDescribe ( inputs [ 2 ] ) - > dimensionFormat ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayReadComputer ,  OpType_TensorArrayRead ,  { 1 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayWrite ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayWriteComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, index, value, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 4  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 3 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outDes  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  ( TensorUtils : : getDescribe ( inputs [ 2 ] ) - > dimensionFormat  ! =  inDes - > dimensionFormat )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 3 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( op - > main_as_TensorArray ( ) - > T ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  writeIndex  =  inputs [ 1 ] - > host < uint32_t > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update arraySize
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( ! inDes - > tensorArrayAttr - > isDynamicSize )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( writeIndex  <  inDes - > tensorArrayAttr - > arraySize ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }  else  if  ( writeIndex  > =  inDes - > tensorArrayAttr - > arraySize )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outDes - > tensorArrayAttr - > arraySize  =  writeIndex  +  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update elemShape
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  writeShape  =  inputs [ 2 ] - > shape ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( outDes - > tensorArrayAttr - > isIdenticalShape )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( outDes - > tensorArrayAttr - > elemShape . empty ( ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape . push_back ( writeShape ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape [ 0 ]  =  writeShape ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            for  ( int  i  =  outDes - > tensorArrayAttr - > elemShape . size ( ) ;  i  < =  writeIndex ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape . push_back ( writeShape ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outDes - > tensorArrayAttr - > elemShape [ writeIndex ]  =  writeShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        updateTensorArrayDims ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( outDes - > tensorArrayAttr  ! =  nullptr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayWriteComputer ,  OpType_TensorArrayWrite ,  { 1 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayGather ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayGatherComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, indices, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: tensor
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 3  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 2 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outDes  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  param  =  op - > main_as_TensorArray ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( param - > T ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outDes - > dimensionFormat  =  inDes - > dimensionFormat ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . dimensions  =  inputs [ 2 ] - > buffer ( ) . dimensions ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setLength ( 0 ,  inputs [ 1 ] - > length ( 0 ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // using param shape
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( param - > element_shape ( )  & &  param - > element_shape ( ) - > size ( )  >  0 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outputs [ 0 ] - > buffer ( ) . dimensions  =  param - > element_shape ( ) - > size ( )  +  1 ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            MNN_ASSERT ( param - > element_shape ( ) - > size ( )  = =  inDes - > tensorArrayAttr - > elemShape [ 0 ] . size ( ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            for  ( int  i  =  0 ;  i  <  param - > element_shape ( ) - > size ( ) ;  i + + )  { 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                int  dimValue  =  param - > element_shape ( ) - > Get ( i ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                if  ( dimValue  <  0 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    dimValue  =  inDes - > tensorArrayAttr - > elemShape [ 0 ] [ i ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outputs [ 0 ] - > setLength ( 1  +  i ,  dimValue ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( inDes - > tensorArrayAttr - > elemShape . size ( )  = =  1 )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                for  ( int  i  =  0 ;  i  <  inDes - > tensorArrayAttr - > elemShape [ 0 ] . size ( ) ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    outputs [ 0 ] - > setLength ( 1  +  i ,  inDes - > tensorArrayAttr - > elemShape [ 0 ] [ i ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayGatherComputer ,  OpType_TensorArrayGather ,  { 1 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayScatter ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayScatterComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, indices, value, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 4  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 3 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outDes  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  ( TensorUtils : : getDescribe ( inputs [ 2 ] ) - > dimensionFormat  ! =  inDes - > dimensionFormat )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 3 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  ( int  i  =  0 ;  i  <  inputs [ 1 ] - > length ( 0 ) ;  i + + )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            int  writeIndex  =  inputs [ 1 ] - > host < uint32_t > ( ) [ i ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( ! inDes - > tensorArrayAttr - > isDynamicSize )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                MNN_ASSERT ( writeIndex  <  inDes - > tensorArrayAttr - > arraySize ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  if  ( writeIndex  > =  inDes - > tensorArrayAttr - > arraySize )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > arraySize  =  writeIndex  +  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            std : : vector < int >  writeElemShape ( inputs [ 2 ] - > shape ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            writeElemShape . erase ( writeElemShape . begin ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( outDes - > tensorArrayAttr - > elemShape . empty ( ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape . emplace_back ( std : : move ( writeElemShape ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
									
										
										
										
											2021-02-07 10:45:07 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape [ 0 ]  =  writeElemShape ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( op - > main_as_TensorArray ( ) - > T ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        updateTensorArrayDims ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( outDes - > tensorArrayAttr  ! =  nullptr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayScatterComputer ,  OpType_TensorArrayScatter ,  { 1 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArraySplit ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArraySplitComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, value, lengths, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 4  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes  =  TensorUtils : : getDescribe ( inputs [ 3 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        auto  taParam  =  op - > main_as_TensorArray ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  splitAxis  =  ( taParam - > axis ( )  +  inputs [ 1 ] - > dimensions ( ) )  %  inputs [ 1 ] - > dimensions ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  keepdims  =  taParam - > keepdims ( ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 3 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( op - > main_as_TensorArray ( ) - > T ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outDes  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( outDes - > tensorArrayAttr - > isIdenticalShape )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            std : : vector < int >  writeElemShape ( inputs [ 1 ] - > shape ( ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            outDes - > tensorArrayAttr - > arraySize  =  writeElemShape [ splitAxis ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( keepdims )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                writeElemShape [ splitAxis ]  =  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                writeElemShape . erase ( writeElemShape . begin ( )  +  splitAxis ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            outDes - > tensorArrayAttr - > elemShape . emplace_back ( std : : move ( writeElemShape ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            auto  value  =  inputs [ 1 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            auto  lengths  =  inputs [ 2 ] ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            bool  scalarSplit  =  ( lengths - > elementSize ( )  = =  1 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            std : : vector < int >  vShape ( value - > shape ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            int  totalLen  =  value - > shape ( ) [ splitAxis ] ,  splitNum ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( scalarSplit )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                splitNum  =  UP_DIV ( totalLen ,  lengths - > host < int > ( ) [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                MNN_ASSERT ( keepdims  | |  lengths - > host < int > ( ) [ 0 ]  = =  1 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                splitNum  =  lengths - > length ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                MNN_ASSERT ( std : : accumulate ( lengths - > host < int > ( ) ,  lengths - > host < int > ( )  +  splitNum ,  0 )  = =  totalLen ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outDes - > tensorArrayAttr - > arraySize  =  splitNum ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            for  ( int  i  =  0 ;  i  <  splitNum ;  + + i )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                auto  elemShape  =  vShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                if  ( scalarSplit )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    if  ( ! keepdims )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        elemShape . erase ( elemShape . begin ( )  +  splitAxis ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        int  splitLen  =  lengths - > host < int > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        elemShape [ splitAxis ]  =  ALIMIN ( splitLen ,  totalLen  -  i  *  splitLen ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    elemShape [ splitAxis ]  =  lengths - > host < int > ( ) [ i ] ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-09-18 15:52:30 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                outDes - > tensorArrayAttr - > elemShape . emplace_back ( std : : move ( elemShape ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        updateTensorArrayDims ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( outDes - > tensorArrayAttr  ! =  nullptr ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArraySplitComputer ,  OpType_TensorArraySplit ,  { 2 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayConcat ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayConcatComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: tensor
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 2  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 1 ] ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr  | |  inDes - > tensorArrayAttr - > arraySize  = =  0 )  { 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 1 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  tpParam  =  op - > main_as_TensorArray ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  concatAxis  =  tpParam - > axis ( ) ,  newAxis  =  tpParam - > new_axis ( ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > setType ( op - > main_as_TensorArray ( ) - > T ( ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-06-10 10:39:50 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        const  auto &  elemShapes  =  inDes - > tensorArrayAttr - > elemShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outShape  =  elemShapes [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        bool  valid  =  true ;  // avoid use MNN_ASSERT because it's no-op in release mode
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  ( int  i  =  1 ;  valid  & &  ( i  <  elemShapes . size ( ) ) ;  + + i )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            auto  elemShape  =  elemShapes [ inDes - > tensorArrayAttr - > isIdenticalShape  ?  0  :  i ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            valid  & =  ( outShape . size ( )  = =  elemShape . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  ( newAxis )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                valid  & =  ( std : : equal ( outShape . begin ( ) ,  outShape . end ( ) ,  elemShape . begin ( ) ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                valid  & =  ( std : : equal ( outShape . begin ( ) ,  outShape . begin ( )  +  concatAxis ,  elemShape . begin ( ) ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                valid  & =  ( std : : equal ( outShape . begin ( )  +  concatAxis  +  1 ,  outShape . end ( ) ,  elemShape . begin ( )  +  concatAxis  +  1 ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                outShape [ concatAxis ]  + =  elemShape [ concatAxis ] ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( ! valid )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ERROR ( " Invalid input, elements in seq have different shape [new_axis=true need same shape, new_axis=false need same shape except concat_axis dim] \n " ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( newAxis )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outShape . insert ( outShape . begin ( )  +  concatAxis ,  inDes - > tensorArrayAttr - > arraySize ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . dimensions  =  outShape . size ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  ( int  i  =  0 ;  i  <  outShape . size ( ) ;  + + i )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outputs [ 0 ] - > setLength ( i ,  outShape [ i ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE ( TensorArrayConcatComputer ,  OpType_TensorArrayConcat ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayInsert ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayInsertComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, position, value, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 4  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 3 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( TensorUtils : : getDescribe ( inputs [ 2 ] ) - > dimensionFormat  ! =  inDes - > dimensionFormat )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( inDes - > tensorArrayAttr - > isDynamicSize ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-06-10 10:39:50 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 3 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outSeq  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) - > tensorArrayAttr ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . type  =  inputs [ 3 ] - > buffer ( ) . type ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  inSeqSize  =  inDes - > tensorArrayAttr - > arraySize ,  insertIndex  =  inputs [ 1 ] - > host < int32_t > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( insertIndex  > =  - inSeqSize  & &  insertIndex  < =  inSeqSize ) ;  // [-n, n]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        insertIndex  + =  ( insertIndex  <  0  ?  inSeqSize  :  0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update arraySize
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outSeq - > arraySize  + =  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update elemShape
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  insertShape  =  inputs [ 2 ] - > shape ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto &  outSeqShapes  =  outSeq - > elemShape ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( outSeq - > isIdenticalShape  & &  ! outSeqShapes . empty ( ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( std : : equal ( insertShape . begin ( ) ,  insertShape . end ( ) ,  outSeqShapes [ 0 ] . begin ( ) ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        }  else  { 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            outSeqShapes . insert ( outSeqShapes . begin ( )  +  insertIndex ,  insertShape ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        updateTensorArrayDims ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayInsertComputer ,  OpType_TensorArrayInsert ,  { 1 } ) ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// ============================ TensorArrayErase ============================
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TensorArrayEraseComputer  :  public  SizeComputer  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // inputs : handle, position, flow_in
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // outputs: flow_out
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    virtual  bool  onComputeSize ( const  MNN : : Op *  op ,  const  std : : vector < Tensor * > &  inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                               const  std : : vector < Tensor * > &  outputs )  const  override  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( 3  = =  inputs . size ( )  & &  1  = =  outputs . size ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  inDes   =  TensorUtils : : getDescribe ( inputs [ 2 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( inDes - > tensorArrayAttr  = =  nullptr )  { 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            MNN_ASSERT ( false ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            return  false ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        MNN_ASSERT ( inDes - > tensorArrayAttr - > isDynamicSize ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2022-06-10 10:39:50 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        copyTensorArrayAttribute ( inputs [ 2 ] ,  outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        auto  outSeq  =  TensorUtils : : getDescribe ( outputs [ 0 ] ) - > tensorArrayAttr ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outputs [ 0 ] - > buffer ( ) . type  =  inputs [ 2 ] - > buffer ( ) . type ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        int  inSeqSize  =  outSeq - > arraySize ,  eraseIndex  =  inputs [ 1 ] - > host < int32_t > ( ) [ 0 ] ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        MNN_ASSERT ( eraseIndex  > =  - inSeqSize  & &  eraseIndex  <  inSeqSize ) ;  // [-n, n-1]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        eraseIndex  + =  ( eraseIndex  <  0  ?  inSeqSize  :  0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update arraySize
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        outSeq - > arraySize  - =  1 ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        // update elemShape
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ( ! outSeq - > isIdenticalShape )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            outSeq - > elemShape . erase ( outSeq - > elemShape . begin ( )  +  eraseIndex ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        updateTensorArrayDims ( outputs [ 0 ] ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
									
										
										
										
											2022-01-04 10:50:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								REGISTER_SHAPE_INPUTS ( TensorArrayEraseComputer ,  OpType_TensorArrayErase ,  { 1 } ) ;  
						 
					
						
							
								
									
										
										
										
											2021-01-06 16:29:37 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								}  // namespace MNN