| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  ConvertUtils.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2020/04/03.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "ConvertUtils.hpp"
 | 
					
						
							|  |  |  | #include "core/OpCommonUtils.hpp"
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | bool ConvertUtils::compute(Tensor* input, Tensor* output, CommandBuffer& res) { | 
					
						
							|  |  |  |     auto inputDes     = TensorUtils::getDescribe(input); | 
					
						
							|  |  |  |     auto outputDes    = TensorUtils::getDescribe(output); | 
					
						
							|  |  |  |     auto inputFormat  = inputDes->dimensionFormat; | 
					
						
							|  |  |  |     auto outputFormat = outputDes->dimensionFormat; | 
					
						
							|  |  |  |     if (MNN_DATA_FORMAT_NC4HW4 == inputFormat) { | 
					
						
							|  |  |  |         inputFormat = MNN_DATA_FORMAT_NCHW; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (MNN_DATA_FORMAT_NC4HW4 == outputFormat) { | 
					
						
							|  |  |  |         outputFormat = MNN_DATA_FORMAT_NCHW; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     std::vector<Tensor::InsideDescribe::Region> inputSlice = {TensorUtils::makeFullSlice(input)}; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     if (inputFormat == outputFormat || 2 == input->dimensions()) { | 
					
						
							|  |  |  |         // No need for treat for NCWH <-> NC4HW4
 | 
					
						
							|  |  |  |         outputDes->regions    = std::move(inputSlice); | 
					
						
							|  |  |  |         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     // NHWC <-> NC4HW4: Turn NHWC to NCHW
 | 
					
						
							|  |  |  |     // TODO for multi input can find better way to compute new slice
 | 
					
						
							|  |  |  |     MNN_ASSERT(4 == input->dimensions()); | 
					
						
							|  |  |  |     auto inside  = input->width() * input->height(); | 
					
						
							|  |  |  |     auto axis    = input->channel(); | 
					
						
							|  |  |  |     auto outside = input->batch(); | 
					
						
							|  |  |  |     auto swap    = [](Tensor::InsideDescribe::Region& inp) { | 
					
						
							|  |  |  |         auto tempStride   = inp.src.stride[2]; | 
					
						
							|  |  |  |         inp.src.stride[2] = inp.src.stride[1]; | 
					
						
							|  |  |  |         inp.src.stride[1] = tempStride; | 
					
						
							|  |  |  |         auto tempSize     = inp.size[2]; | 
					
						
							|  |  |  |         inp.size[2]       = inp.size[1]; | 
					
						
							|  |  |  |         inp.size[1]       = tempSize; | 
					
						
							|  |  |  |         inp.dst.stride[2] = 1; | 
					
						
							|  |  |  |         inp.dst.stride[1] = inp.size[2]; | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     if (inputSlice.size() == 1) { | 
					
						
							|  |  |  |         auto& inp       = inputSlice[0]; | 
					
						
							|  |  |  |         bool canReshape = false; | 
					
						
							|  |  |  |         if (inputFormat == MNN_DATA_FORMAT_NCHW) { | 
					
						
							|  |  |  |             canReshape = TensorUtils::reshapeSlice(inp, outside, inside, axis); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             canReshape = TensorUtils::reshapeSlice(inp, outside, axis, inside); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (canReshape) { | 
					
						
							|  |  |  |             swap(inp); | 
					
						
							|  |  |  |             outputDes->regions    = std::move(inputSlice); | 
					
						
							|  |  |  |             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto slice = TensorUtils::makeFullSlice(input); | 
					
						
							|  |  |  |     if (inputFormat == MNN_DATA_FORMAT_NCHW) { | 
					
						
							|  |  |  |         TensorUtils::reshapeSlice(slice, outside, inside, axis); | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |         TensorUtils::reshapeSlice(slice, outside, axis, inside); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     swap(slice); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     outputDes->regions    = {slice}; | 
					
						
							|  |  |  |     outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | void ConvertUtils::broadcastto(Tensor* input, Tensor* output, bool forward) { | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |      | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     auto outputDes        = TensorUtils::getDescribe(output); | 
					
						
							|  |  |  |     outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; | 
					
						
							| 
									
										
										
										
											2022-08-12 10:30:48 +08:00
										 |  |  |     if (TensorUtils::getRawSize(input) == TensorUtils::getRawSize(output)) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         // Just Copy Tensor
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         outputDes->regions = {TensorUtils::makeFullSlice(input)}; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     // if forward ( tf select broadcast )
 | 
					
						
							|  |  |  |     if (forward) { | 
					
						
							|  |  |  |         MNN_ASSERT(input->dimensions() == 1 && output->dimensions() > 1); | 
					
						
							|  |  |  |         MNN_ASSERT(input->length(0) == output->length(0)); | 
					
						
							|  |  |  |         int srcSize = input->length(0); | 
					
						
							|  |  |  |         int multipler = output->length(1); | 
					
						
							|  |  |  |         for (int i = 2; i < output->dimensions(); i++) { | 
					
						
							|  |  |  |             multipler *= output->length(i); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // [srcSize] -> [srcSize, multipler]
 | 
					
						
							|  |  |  |         outputDes->regions.resize(1); | 
					
						
							|  |  |  |         auto& reg = outputDes->regions[0]; | 
					
						
							|  |  |  |         reg.size[0] = 1; | 
					
						
							|  |  |  |         reg.size[1] = srcSize; | 
					
						
							|  |  |  |         reg.size[2] = multipler; | 
					
						
							|  |  |  |         reg.src.offset = 0; | 
					
						
							|  |  |  |         reg.src.stride[0] = srcSize; | 
					
						
							|  |  |  |         reg.src.stride[1] = 1; | 
					
						
							|  |  |  |         reg.src.stride[2] = 0; | 
					
						
							|  |  |  |         reg.dst.offset = 0; | 
					
						
							|  |  |  |         reg.dst.stride[0] = srcSize * multipler; | 
					
						
							|  |  |  |         reg.dst.stride[1] = multipler; | 
					
						
							|  |  |  |         reg.dst.stride[2] = 1; | 
					
						
							|  |  |  |         reg.origin = input; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     int32_t inputShape[MNN_MAX_TENSOR_DIM]; | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |     int32_t outputShape[MNN_MAX_TENSOR_DIM]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     auto outputDim = output->dimensions(); | 
					
						
							|  |  |  |     for (int i=0; i<outputDim; ++i) { | 
					
						
							|  |  |  |         inputShape[i] = 1; | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |         outputShape[i] = output->length(i); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     int offset = outputDim - input->dimensions(); | 
					
						
							|  |  |  |     for (int i = 0; i < input->dimensions(); ++i) { | 
					
						
							|  |  |  |         inputShape[i + offset] = input->length(i); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // Squeeze consecutive 1 dimension
 | 
					
						
							|  |  |  |     while(outputDim >= 2) { | 
					
						
							|  |  |  |         bool canFuse = false; | 
					
						
							|  |  |  |         for(int i=0; i<outputDim-1; ++i) { | 
					
						
							|  |  |  |             if(inputShape[i] == 1 && inputShape[i+1] == 1) { | 
					
						
							|  |  |  |                 for(int j=i+1; j<outputDim; j++) { | 
					
						
							|  |  |  |                     inputShape[j] = inputShape[j+1]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 outputShape[i] *= outputShape[i+1]; | 
					
						
							|  |  |  |                 for(int j=i+1; j<outputDim; j++) { | 
					
						
							|  |  |  |                     outputShape[j] = outputShape[j+1]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 outputDim--; | 
					
						
							|  |  |  |                 i--; | 
					
						
							|  |  |  |                 canFuse = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if(!canFuse) { | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     // Compute Strides
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     int sepInputShapeSize = 0; | 
					
						
							|  |  |  |     int sepOutputShapeSize = 0; | 
					
						
							|  |  |  |     int sepInputShape[MNN_MAX_TENSOR_DIM]; | 
					
						
							|  |  |  |     int sepOutputShape[MNN_MAX_TENSOR_DIM]; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     int currentInput  = 1; | 
					
						
							|  |  |  |     int currentOutput = 1; | 
					
						
							|  |  |  |     for (int i = 0; i < outputDim; ++i) { | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |         if (inputShape[i] != outputShape[i]) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             if (1 < currentOutput) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                 sepInputShape[sepInputShapeSize++] = currentInput; | 
					
						
							|  |  |  |                 sepOutputShape[sepOutputShapeSize++] = currentOutput; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             sepInputShape[sepInputShapeSize++] = (inputShape[i]); | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |             sepOutputShape[sepOutputShapeSize++] = (outputShape[i]); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             currentInput  = 1; | 
					
						
							|  |  |  |             currentOutput = 1; | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             currentInput *= inputShape[i]; | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  |             currentOutput *= outputShape[i]; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (currentOutput != 1 || currentInput != 1) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         sepInputShape[sepInputShapeSize++] = (currentInput); | 
					
						
							|  |  |  |         sepOutputShape[sepOutputShapeSize++] = (currentOutput); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     int seperateOutputStrides[MNN_MAX_TENSOR_DIM]; | 
					
						
							|  |  |  |     int seperateInputStrides[MNN_MAX_TENSOR_DIM]; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     OpCommonUtils::computeStride(seperateOutputStrides, sepOutputShape, sepOutputShapeSize); | 
					
						
							|  |  |  |     OpCommonUtils::computeStride(seperateInputStrides, sepInputShape, sepInputShapeSize); | 
					
						
							|  |  |  |     for (int i = 0; i < sepInputShapeSize; ++i) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         if (1 == sepInputShape[i]) { | 
					
						
							|  |  |  |             seperateInputStrides[i] = 0; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Split region by size, use stride to determine src and dst mapping
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     int remainDimSize = sepInputShapeSize > 3 ? (int)sepInputShapeSize - 3 : 0; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     int remainStride[MNN_MAX_TENSOR_DIM]; | 
					
						
							|  |  |  |     int remainSize = OpCommonUtils::computeStride(remainStride, sepOutputShape, remainDimSize); | 
					
						
							| 
									
										
										
										
											2022-09-09 17:21:11 +08:00
										 |  |  |     outputDes->regions.clear(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     outputDes->regions.resize(remainSize); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     int cords[MNN_MAX_TENSOR_DIM]; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     for (int index = 0; index < remainSize; ++index) { | 
					
						
							|  |  |  |         OpCommonUtils::unravelIndexHelper(cords, remainStride, remainDimSize, index); | 
					
						
							|  |  |  |         auto& reg = outputDes->regions[index]; | 
					
						
							|  |  |  |         for (int i = 0; i < remainDimSize; ++i) { | 
					
						
							|  |  |  |             reg.src.offset += (cords[i] * seperateInputStrides[i]); | 
					
						
							|  |  |  |             reg.dst.offset += (cords[i] * seperateOutputStrides[i]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         reg.origin = input; | 
					
						
							|  |  |  |         for (int i = 0; i < 3; ++i) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             auto match = (int)sepOutputShapeSize - i - 1; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             if (match < 0) { | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             reg.size[3 - i - 1]       = sepOutputShape[match]; | 
					
						
							|  |  |  |             reg.src.stride[3 - i - 1] = seperateInputStrides[match]; | 
					
						
							|  |  |  |             reg.dst.stride[3 - i - 1] = seperateOutputStrides[match]; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-11-08 17:05:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |