| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ShapeSlice.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/10.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "shape/SizeComputer.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Macro.h"
 | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  | #include <algorithm>
 | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  | #include <numeric>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SliceComputer : public SizeComputer { | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         //MNN_ASSERT(1 == inputs.size());
 | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |         auto outputSize = (int)outputs.size(); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         auto slice = op->main_as_Slice(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto& input = inputs[0]->buffer(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int axis = slice->axis(); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |         if (axis < 0) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             axis += input.dimensions; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |         /*
 | 
					
						
							|  |  |  |          If we want split (2, 10) => (2, 3) + (2, 5) + (2, 2), slicePoints is | 
					
						
							|  |  |  |          1. [3, 8, 10] when slice->sourceType = NetSource_CAFFE | 
					
						
							|  |  |  |          2. [3, 5, 2] otherwise | 
					
						
							|  |  |  |          */ | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         if (MNN::NetSource_CAFFE == slice->sourceType()) { | 
					
						
							|  |  |  |             // caffe Slice
 | 
					
						
							|  |  |  |             int previous = 0; | 
					
						
							|  |  |  |             for (int i = 0; i < slice->slicePoints()->size(); ++i) { | 
					
						
							|  |  |  |                 int sliceIndex    = slice->slicePoints()->data()[i]; | 
					
						
							|  |  |  |                 auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                 output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                 ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  |                 output.type             = input.type; | 
					
						
							|  |  |  |                 output.dim[axis].extent = sliceIndex - previous; | 
					
						
							|  |  |  |                 previous                = sliceIndex; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // Compute Last
 | 
					
						
							|  |  |  |             auto& output = outputs[outputs.size() - 1]->buffer(); | 
					
						
							| 
									
										
										
										
											2020-04-27 09:46:13 +08:00
										 |  |  |             output.dimensions = input.dimensions; | 
					
						
							|  |  |  |             output.type             = input.type; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             output.dim[axis].extent = input.dim[axis].extent - previous; | 
					
						
							|  |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |             // tensorflow/Torch Split
 | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |             if (inputs.size() == 1 && (nullptr == slice->slicePoints() || 1 == slice->slicePoints()->size())) { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |                 // slicePoint size is 1:
 | 
					
						
							|  |  |  |                 // TF value is num_split, Torch value is split_size
 | 
					
						
							|  |  |  |                 int numSplits = outputSize, | 
					
						
							|  |  |  |                     splitDim = input.dim[axis].extent / numSplits; | 
					
						
							|  |  |  |                 if (MNN::NetSource_TORCH == slice->sourceType()) { | 
					
						
							|  |  |  |                     if (nullptr != slice->slicePoints()) { | 
					
						
							|  |  |  |                         splitDim = slice->slicePoints()->data()[0]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     numSplits = input.dim[axis].extent / splitDim; | 
					
						
							|  |  |  |                 } else if (MNN::NetSource_TENSORFLOW == slice->sourceType()) { | 
					
						
							|  |  |  |                     if (nullptr != slice->slicePoints() && slice->slicePoints()->data()[0] != outputSize) { | 
					
						
							|  |  |  |                         numSplits = slice->slicePoints()->data()[0]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     MNN_ASSERT(0 == input.dim[axis].extent % numSplits); | 
					
						
							|  |  |  |                     splitDim = input.dim[axis].extent / numSplits; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |                 } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |                 for (int i = 0; i < outputSize; i++) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                     auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                     output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                     output.type       = input.type; | 
					
						
							|  |  |  |                     ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							|  |  |  |                     output.dim[axis].extent = splitDim; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |                 std::vector<int> slicePoints; | 
					
						
							|  |  |  |                 if (inputs.size() == 2) { | 
					
						
							|  |  |  |                     slicePoints.assign(inputs[1]->host<int>(), inputs[1]->host<int>() + inputs[1]->elementSize()); | 
					
						
							|  |  |  |                 } else if (slice->slicePoints() != nullptr) { | 
					
						
							|  |  |  |                     slicePoints.assign(slice->slicePoints()->begin(), slice->slicePoints()->end()); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 int totalLen = std::accumulate(slicePoints.begin(), slicePoints.end(), 0); | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |                 if (totalLen > inputs[0]->length(axis)) { | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |                     MNN_ASSERT(false); | 
					
						
							|  |  |  |                     return false; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 int numberSplits = slicePoints.size(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |                 MNN_ASSERT(0 < numberSplits); | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |                 numberSplits = std::min(numberSplits, outputSize); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                 int determineTensorIndex = -1; | 
					
						
							|  |  |  |                 int maxSize              = 0; | 
					
						
							| 
									
										
										
										
											2019-11-15 14:22:45 +08:00
										 |  |  |                 for (int i = 0; i < numberSplits; i++) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                     auto& output      = outputs[i]->buffer(); | 
					
						
							|  |  |  |                     output.type       = input.type; | 
					
						
							|  |  |  |                     output.dimensions = input.dimensions; | 
					
						
							|  |  |  |                     ::memcpy(output.dim, input.dim, input.dimensions * sizeof(halide_dimension_t)); | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |                     auto length = slicePoints[i]; | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |                     if (-1 != length) { | 
					
						
							|  |  |  |                         output.dim[axis].extent = length; | 
					
						
							|  |  |  |                         maxSize += length; | 
					
						
							|  |  |  |                     } else { | 
					
						
							|  |  |  |                         if (determineTensorIndex >= 0) { | 
					
						
							|  |  |  |                             // Don't support two -1 points
 | 
					
						
							|  |  |  |                             return false; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         determineTensorIndex = i; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 if (determineTensorIndex >= 0) { | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |                     auto& output            = outputs[determineTensorIndex]->buffer(); | 
					
						
							|  |  |  |                     output.dim[axis].extent = input.dim[axis].extent - maxSize; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |         for (int i=0; i<outputs.size(); ++i) { | 
					
						
							|  |  |  |             TensorUtils::getDescribe(outputs[i])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  | REGISTER_SHAPE_INPUTS(SliceComputer, OpType_Slice, {1}); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } // namespace MNN
 |