| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ShapeSlice.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/10.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Macro.h"
 | 
					
						
							|  |  |  | #include "core/SizeComputer.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  | #include <algorithm>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SliceComputer : public SizeComputer { | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         MNN_ASSERT(1 == inputs.size()); | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |         auto outputSize = (int)outputs.size(); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         auto slice = op->main_as_Slice(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto& input = inputs[0]->buffer(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int axis = slice->axis(); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |         if (axis < 0) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             axis += input.dimensions; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (MNN::NetSource_CAFFE == slice->sourceType()) { | 
					
						
							|  |  |  |             // caffe Slice
 | 
					
						
							|  |  |  |             int previous = 0; | 
					
						
							|  |  |  |             for (int i = 0; i < slice->slicePoints()->size(); ++i) { | 
					
						
							|  |  |  |                 int sliceIndex    = slice->slicePoints()->data()[i]; | 
					
						
							|  |  |  |                 auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                 output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                 ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  |                 output.type             = input.type; | 
					
						
							|  |  |  |                 output.dim[axis].extent = sliceIndex - previous; | 
					
						
							|  |  |  |                 previous                = sliceIndex; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // Compute Last
 | 
					
						
							|  |  |  |             auto& output = outputs[outputs.size() - 1]->buffer(); | 
					
						
							| 
									
										
										
										
											2020-04-27 09:46:13 +08:00
										 |  |  |             output.dimensions = input.dimensions; | 
					
						
							|  |  |  |             output.type             = input.type; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             output.dim[axis].extent = input.dim[axis].extent - previous; | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             // tensorflow Split
 | 
					
						
							|  |  |  |             if (1 == slice->slicePoints()->size()) { | 
					
						
							|  |  |  |                 // scalar
 | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |                 int numSplits = slice->slicePoints()->data()[0]; | 
					
						
							|  |  |  |                 numSplits = std::min(numSplits, outputSize); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 MNN_ASSERT(0 == input.dim[axis].extent % numSplits); | 
					
						
							|  |  |  |                 const int splitDim = input.dim[axis].extent / numSplits; | 
					
						
							|  |  |  |                 for (int i = 0; i < numSplits; i++) { | 
					
						
							|  |  |  |                     auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                     output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                     output.type       = input.type; | 
					
						
							|  |  |  |                     ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  |                     output.dim[axis].extent = splitDim; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 // one dimension tensor, ex: [5,30]=>[5,4]+[5,15]+[5,11], slicePoints is [4, 15, 11]
 | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |                 int numberSplits = slice->slicePoints()->size(); | 
					
						
							|  |  |  |                 numberSplits = std::min(numberSplits, outputSize); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                 int determineTensorIndex = -1; | 
					
						
							|  |  |  |                 int maxSize              = 0; | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |                 for (int i = 0; i < numberSplits; i++) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                     auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                     output.type       = input.type; | 
					
						
							|  |  |  |                     output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                     ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                     auto length = slice->slicePoints()->data()[i]; | 
					
						
							|  |  |  |                     if (-1 != length) { | 
					
						
							|  |  |  |                         output.dim[axis].extent = length; | 
					
						
							|  |  |  |                         maxSize += length; | 
					
						
							|  |  |  |                     } else { | 
					
						
							|  |  |  |                         if (determineTensorIndex >= 0) { | 
					
						
							|  |  |  |                             // Don't support two -1 points
 | 
					
						
							|  |  |  |                             return false; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         determineTensorIndex = i; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 if (determineTensorIndex >= 0) { | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |                     auto& output            = outputs[determineTensorIndex]->buffer(); | 
					
						
							|  |  |  |                     output.dim[axis].extent = input.dim[axis].extent - maxSize; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |         for (int i=0; i<outputs.size(); ++i) { | 
					
						
							|  |  |  |             TensorUtils::getDescribe(outputs[i])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_SHAPE(SliceComputer, OpType_Slice); | 
					
						
							|  |  |  | } // namespace MNN
 |