// Mirror of https://github.com/alibaba/MNN.git (C++, 99 lines, 3.8 KiB)
//
//  PermuteExecution.cpp
//  MNN
//
//  Created by MNN on 2019/02/28.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "PermuteExecution.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>

#include <Macro.h>
#include "TensorUtils.hpp"
#include "core/OpenCLBackend.hpp"

namespace MNN {
 | 
						|
namespace OpenCL {
 | 
						|
 | 
						|
PermuteExecution::PermuteExecution(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend)
 | 
						|
    : CommonExecution(backend) {
 | 
						|
    auto shape = op->main_as_Permute()->dims();
 | 
						|
    // FIXME, support less than 4
 | 
						|
    MNN_ASSERT(shape->size() == 4);
 | 
						|
    mDims.resize(4);
 | 
						|
    for (int i = 0; i < shape->size(); ++i) {
 | 
						|
        auto dim   = shape->data()[i];
 | 
						|
        mDims[dim] = i;
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
ErrorCode PermuteExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
 | 
						|
    Tensor *input  = inputs[0];
 | 
						|
    Tensor *output = outputs[0];
 | 
						|
    // FIXME, support nhwc format
 | 
						|
    MNN_ASSERT(input->getDimensionType() != Tensor::TENSORFLOW);
 | 
						|
 | 
						|
    std::vector<int> inputShape  = tensorShapeFormat(input);
 | 
						|
    std::vector<int> outputShape = tensorShapeFormat(output);
 | 
						|
 | 
						|
    auto runTime        = ((OpenCLBackend *)backend())->getOpenCLRuntime();
 | 
						|
    auto bufferPool     = ((OpenCLBackend *)backend())->getBufferPool();
 | 
						|
    auto bufferUnitSize = runTime->isSupportedFP16() ? sizeof(int16_t) : sizeof(float);
 | 
						|
    auto bufferSize0    = UP_DIV(outputShape[3], 4) * 4 * outputShape[0] * outputShape[1] * outputShape[2];
 | 
						|
    auto bufferSize1    = UP_DIV(inputShape[3], 4) * 4 * inputShape[0] * inputShape[1] * inputShape[2];
 | 
						|
    mTempInput          = bufferPool->alloc(std::max(bufferSize0, bufferSize1) * bufferUnitSize);
 | 
						|
    bufferPool->recycle(mTempInput);
 | 
						|
 | 
						|
    mUnits.resize(2);
 | 
						|
    int offset[] = {0, 0, 0, 0};
 | 
						|
 | 
						|
    // NCHW's stride, use the stride of nhwc
 | 
						|
    int outputStride[] = {outputShape[1] * outputShape[2] * outputShape[3], 1, outputShape[2] * outputShape[3],
 | 
						|
                          outputShape[3]};
 | 
						|
    int permuteInputStride[4];
 | 
						|
    for (int i = 0; i < mDims.size(); ++i) {
 | 
						|
        permuteInputStride[i] = outputStride[mDims[i]];
 | 
						|
    }
 | 
						|
    int inputWH[]  = {inputShape[2], inputShape[1]};
 | 
						|
    int outputWH[] = {outputShape[2], outputShape[1]};
 | 
						|
    {
 | 
						|
        int region[] = {inputShape[0], UP_DIV(inputShape[3], 4), inputShape[1], inputShape[2]};
 | 
						|
        uint32_t gw0 = region[1] * region[3];
 | 
						|
        uint32_t gw1 = region[0] * region[2];
 | 
						|
        auto &unit   = mUnits[0];
 | 
						|
        unit.kernel  = runTime->buildKernel("blitBuffer", "blitImageToBuffer", {});
 | 
						|
        unit.kernel.setArg(0, openCLImage(inputs[0]));
 | 
						|
        unit.kernel.setArg(1, *mTempInput);
 | 
						|
        unit.kernel.setArg(2, offset);
 | 
						|
        unit.kernel.setArg(3, offset);
 | 
						|
        unit.kernel.setArg(4, region);
 | 
						|
        unit.kernel.setArg(5, inputWH);
 | 
						|
        unit.kernel.setArg(6, 4 * sizeof(int), permuteInputStride);
 | 
						|
        unit.kernel.setArg(7, 4 * sizeof(int), inputShape.data());
 | 
						|
        unit.localWorkSize  = {16, 16};
 | 
						|
        unit.globalWorkSize = {UP_DIV(gw0, 16) * 16, UP_DIV(gw1, 16) * 16};
 | 
						|
    }
 | 
						|
    {
 | 
						|
        int region[] = {outputShape[0], UP_DIV(outputShape[3], 4), outputShape[1], outputShape[2]};
 | 
						|
 | 
						|
        auto &unit          = mUnits[1];
 | 
						|
        unit.kernel         = runTime->buildKernel("blitBuffer", "blitBufferToImage", {});
 | 
						|
        unit.localWorkSize  = {16, 16};
 | 
						|
        unit.globalWorkSize = {(uint32_t)UP_DIV(region[3] * region[1], 16) * 16,
 | 
						|
                               (uint32_t)UP_DIV(region[2] * region[0], 16) * 16};
 | 
						|
        unit.kernel.setArg(0, *mTempInput);
 | 
						|
        unit.kernel.setArg(1, openCLImage(output));
 | 
						|
        unit.kernel.setArg(2, offset);
 | 
						|
        unit.kernel.setArg(3, offset);
 | 
						|
        unit.kernel.setArg(4, region);
 | 
						|
        unit.kernel.setArg(5, outputStride);
 | 
						|
        unit.kernel.setArg(6, outputWH);
 | 
						|
        unit.kernel.setArg(7, outputWH);
 | 
						|
    }
 | 
						|
    return NO_ERROR;
 | 
						|
}
 | 
						|
 | 
						|
OpenCLCreatorRegister<TypedCreator<PermuteExecution>> __permute_op(OpType_Permute);
 | 
						|
 | 
						|
} // namespace OpenCL
 | 
						|
} // namespace MNN
 |