//
//  ShapeTile.cpp
//  MNN
//
//  Created by MNN on 2019/01/10.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <cstdint>
#include <cstring>
#include <limits>

#include "shape/SizeComputer.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"

							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TileComputer : public SizeComputer { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(2 == inputs.size()); | 
					
						
							|  |  |  |         MNN_ASSERT(1 == outputs.size()); | 
					
						
							|  |  |  |         auto& input    = inputs[0]->buffer(); | 
					
						
							|  |  |  |         auto multiples = inputs[1]; | 
					
						
							|  |  |  |         MNN_ASSERT(multiples->getType().code == halide_type_int); | 
					
						
							|  |  |  |         auto& output = outputs[0]->buffer(); | 
					
						
							|  |  |  |         // Expected multiples argument to be a 1-D vector of length input.dimensions
 | 
					
						
							|  |  |  |         MNN_ASSERT(1 == multiples->buffer().dimensions) | 
					
						
							|  |  |  |         MNN_ASSERT(input.dimensions == multiples->buffer().dim[0].extent); | 
					
						
							|  |  |  |         const int inputDims = input.dimensions; | 
					
						
							|  |  |  |         ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  |         output.dimensions = inputDims; | 
					
						
							|  |  |  |         output.type       = input.type; | 
					
						
							| 
									
										
										
										
											2019-07-25 13:36:35 +08:00
										 |  |  |          | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         for (int i = 0; i < inputDims; ++i) { | 
					
						
							|  |  |  |             output.dim[i].extent = input.dim[i].extent * multiples->host<int32_t>()[i]; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |         TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  | REGISTER_SHAPE_INPUTS(TileComputer, OpType_Tile, {1}); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |