| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  GeometryPermute.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2020/04/03.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-18 19:11:50 +08:00
										 |  |  | #include <algorithm>
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "geometry/GeometryComputer.hpp"
 | 
					
						
							|  |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | class GeometryPermute : public GeometryComputer { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                            Context& context, CommandBuffer& res) const override { | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |         int dims = inputs[0]->buffer().dimensions; | 
					
						
							|  |  |  |         int neworder[MNN_MAX_TENSOR_DIM]; | 
					
						
							|  |  |  |         // get neworder
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         if (op->type() == OpType_Permute) { | 
					
						
							|  |  |  |             auto shapeValue = op->main_as_Permute()->dims(); | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |             if (nullptr != shapeValue) { | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |                 for (int i = 0; i < dims; ++i) { | 
					
						
							|  |  |  |                     neworder[i] = shapeValue->data()[i]; | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |                 for (int i = 0; i < dims; ++i) { | 
					
						
							|  |  |  |                     neworder[i] = dims - i - 1; | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } else if (op->type() == OpType_Transpose) { | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |             MNN_ASSERT(inputs.size() > 1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             auto shapeValue = inputs[1]->host<int32_t>(); | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |             for (int i = 0; i < dims; ++i) { | 
					
						
							|  |  |  |                 neworder[i] = shapeValue[i]; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-05-11 19:17:02 +08:00
										 |  |  |         return GeometryComputer::ComputePermuteRegion(inputs[0], outputs[0], neworder, dims); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | static void _create() { | 
					
						
							|  |  |  |     std::shared_ptr<GeometryComputer> comp(new GeometryPermute); | 
					
						
							|  |  |  |     GeometryComputer::registerGeometryComputer(comp, {OpType_Transpose, OpType_Permute}); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_GEOMETRY(GeometryPermute, _create); | 
					
						
							|  |  |  | }; // namespace MNN
 |