mirror of https://github.com/alibaba/MNN.git
//
//  ConvertUtils.cpp
//  MNN
//
//  Created by MNN on 2020/04/03.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "ConvertUtils.hpp"
#include "core/OpCommonUtils.hpp"
namespace MNN {
bool ConvertUtils::compute(Tensor* input, Tensor* output, CommandBuffer& res) {
    auto inputDes     = TensorUtils::getDescribe(input);
    auto outputDes    = TensorUtils::getDescribe(output);
    auto inputFormat  = inputDes->dimensionFormat;
    auto outputFormat = outputDes->dimensionFormat;
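    // Note: NC4HW4 is MNN's channel-packed variant of NCHW (channels grouped in
    // blocks of 4). For the purpose of computing copy regions it is treated as
    // plain NCHW below, so only NHWC <-> NCHW actually needs a transpose.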
    if (MNN_DATA_FORMAT_NC4HW4 == inputFormat) {
        inputFormat = MNN_DATA_FORMAT_NCHW;
    }
    if (MNN_DATA_FORMAT_NC4HW4 == outputFormat) {
        outputFormat = MNN_DATA_FORMAT_NCHW;
    }
    auto inputSlice = inputDes->regions;
    MNN_ASSERT(input->dimensions() >= 1);
    MNN_ASSERT(output->dimensions() == input->dimensions());
    if (inputSlice.empty()) {
        inputSlice.resize(1);
        // Create full reference
        inputSlice[0] = TensorUtils::makeFullSlice(input);
    }
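    // When no transpose is needed, the conversion is expressed as a zero-copy
    // view: the output takes over the input's regions, and MEMORY_VIRTUAL marks
    // its content as described by those regions rather than by its own buffer.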
    if (inputFormat == outputFormat || 2 == input->dimensions()) {
        // No special treatment needed for NCHW <-> NC4HW4, or for 2-D tensors where NHWC and NCHW coincide
        outputDes->regions    = std::move(inputSlice);
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        return true;
    }
    // NHWC <-> NC4HW4: Turn NHWC to NCHW
    // TODO: for multiple input regions, find a better way to compute the new slices
    MNN_ASSERT(4 == input->dimensions());
    auto inside  = input->width() * input->height();
    auto axis    = input->channel();
    auto outside = input->batch();
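    // The 4-D tensor is folded into three extents: outside = batch, axis =
    // channel, inside = height * width. The lambda below transposes the two
    // inner axes of such a 3-axis region: it swaps size[1]/size[2] together with
    // the corresponding source strides, then rewrites the destination strides so
    // the output is written densely in the transposed order.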
    auto swap    = [](Tensor::InsideDescribe::Region& inp) {
        auto tempStride   = inp.src.stride[2];
        inp.src.stride[2] = inp.src.stride[1];
        inp.src.stride[1] = tempStride;
        auto tempSize     = inp.size[2];
        inp.size[2]       = inp.size[1];
        inp.size[1]       = tempSize;
        inp.dst.stride[2] = 1;
        inp.dst.stride[1] = inp.size[2];
    };
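    // If there is exactly one region, first try to fold it in place into the
    // 3-axis (outside, a, b) view; when reshapeSlice reports that the existing
    // region cannot be expressed that way, fall through to building a fresh
    // full slice of the input below.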
    if (inputSlice.size() == 1) {
        auto& inp       = inputSlice[0];
        bool canReshape = false;
        if (inputFormat == MNN_DATA_FORMAT_NCHW) {
            canReshape = TensorUtils::reshapeSlice(inp, outside, inside, axis);
        } else {
            canReshape = TensorUtils::reshapeSlice(inp, outside, axis, inside);
        }
        if (canReshape) {
            swap(inp);
            outputDes->regions    = std::move(inputSlice);
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            return true;
        }
    }
    auto slice = TensorUtils::makeFullSlice(input);
    if (inputFormat == MNN_DATA_FORMAT_NCHW) {
        TensorUtils::reshapeSlice(slice, outside, inside, axis);
    } else {
        TensorUtils::reshapeSlice(slice, outside, axis, inside);
    }
    swap(slice);

    outputDes->regions    = {slice};
    outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;

    return true;
}

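// broadcastto() expresses a broadcast as zero-copy regions: the input shape is
// right-aligned against the output shape and padded with leading 1s, runs of
// matching dimensions are collapsed, broadcast axes get a source stride of 0,
// and the result is emitted as one or more (at most 3-axis) regions.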
void ConvertUtils::broadcastto(Tensor* input, Tensor* output) {
    auto inputDes         = TensorUtils::getDescribe(input);
    auto outputDes        = TensorUtils::getDescribe(output);
    outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
    if (input->elementSize() == output->elementSize()) {
        // Same element count: just copy the tensor's regions
        auto inputSlice = inputDes->regions;
        if (inputSlice.empty()) {
            // Create full reference
            Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
            inputSlice.emplace_back(std::move(totalSlice));
        }
        outputDes->regions = std::move(inputSlice);
        return;
    }
    int32_t inputShape[MNN_MAX_TENSOR_DIM];
    auto outputDim = output->dimensions();
    for (int i = 0; i < outputDim; ++i) {
        inputShape[i] = 1;
    }
    int offset = outputDim - input->dimensions();
    for (int i = 0; i < input->dimensions(); ++i) {
        inputShape[i + offset] = input->length(i);
    }
    // Compute Strides
    std::vector<int> sepInputShape;
    std::vector<int> sepOutputShape;
    int currentInput  = 1;
    int currentOutput = 1;
    for (int i = 0; i < outputDim; ++i) {
        if (inputShape[i] != output->length(i)) {
            if (1 < currentOutput) {
                sepInputShape.emplace_back(currentInput);
                sepOutputShape.emplace_back(currentOutput);
            }
            sepInputShape.emplace_back(inputShape[i]);
            sepOutputShape.emplace_back(output->length(i));
            currentInput  = 1;
            currentOutput = 1;
        } else {
            currentInput *= inputShape[i];
            currentOutput *= output->length(i);
        }
    }
    if (currentOutput != 1 || currentInput != 1) {
        sepInputShape.emplace_back(currentInput);
        sepOutputShape.emplace_back(currentOutput);
    }
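    // Illustration of the collapsing above: broadcasting an input of shape
    // [2, 1, 4, 5] to an output of shape [2, 3, 4, 5] yields
    //   sepInputShape  = {2, 1, 20}
    //   sepOutputShape = {2, 3, 20}
    // i.e. the mismatching (broadcast) axis is kept as its own pair, while the
    // matching runs before and after it are merged into single dimensions.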
    int seperateOutputStrides[MNN_MAX_TENSOR_DIM];
    int seperateInputStrides[MNN_MAX_TENSOR_DIM];
    OpCommonUtils::computeStride(seperateOutputStrides, sepOutputShape.data(), sepOutputShape.size());
    OpCommonUtils::computeStride(seperateInputStrides, sepInputShape.data(), sepInputShape.size());
    for (int i = 0; i < sepInputShape.size(); ++i) {
        if (1 == sepInputShape[i]) {
            seperateInputStrides[i] = 0;
        }
    }
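    // A zero source stride makes every output element along a broadcast axis
    // read the same input element. Continuing the example above (assuming
    // computeStride fills ordinary row-major strides):
    //   seperateInputStrides  = {20, 20, 1} -> {20, 0, 1}
    //   seperateOutputStrides = {60, 20, 1}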

    // Split region by size, use stride to determine src and dst mapping
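    // A Region describes at most a 3-axis strided copy, so when more than three
    // collapsed dimensions remain, the leading ones are enumerated explicitly:
    // remainSize regions in total, one per combination of their indices, with
    // each region's src/dst offsets recovered by unraveling the flat index over
    // those leading dimensions.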
    int remainDimSize = sepInputShape.size() > 3 ? (int)sepInputShape.size() - 3 : 0;
    std::vector<int> remainStride(remainDimSize + 1);
    int remainSize = OpCommonUtils::computeStride(remainStride.data(), sepOutputShape.data(), remainDimSize);
    outputDes->regions.resize(remainSize);
    std::vector<int> cords(remainDimSize + 1);
    for (int index = 0; index < remainSize; ++index) {
        OpCommonUtils::unravelIndexHelper(cords, remainStride, remainDimSize, index);
        auto& reg = outputDes->regions[index];
        for (int i = 0; i < remainDimSize; ++i) {
            reg.src.offset += (cords[i] * seperateInputStrides[i]);
            reg.dst.offset += (cords[i] * seperateOutputStrides[i]);
        }
        reg.origin = input;
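        // Fill the innermost (up to) three collapsed dimensions into the region,
        // right-aligned into size/stride slots 2, 1, 0; when fewer than three
        // dimensions exist, the leading slots are simply not written.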
        for (int i = 0; i < 3; ++i) {
            auto match = (int)sepOutputShape.size() - i - 1;
            if (match < 0) {
                continue;
            }
            reg.size[3 - i - 1]       = sepOutputShape[match];
            reg.src.stride[3 - i - 1] = seperateInputStrides[match];
            reg.dst.stride[3 - i - 1] = seperateOutputStrides[match];
        }
    }
}

} // namespace MNN