mirror of https://github.com/alibaba/MNN.git
//
//  WrapExecution.cpp
//  MNN
//
//  Created by MNN on 2018/09/03.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <cmath>
#include "core/WrapExecution.hpp"
#include "core/TensorUtils.hpp"
#include "core/OpCommonUtils.hpp"
#include "core/Concurrency.h"
#include "backend/cpu/CPUCast.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"

namespace MNN {
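// Decide whether `input` must be wrapped with a copy before `curBackend`
// can consume it. No wrap is needed when the tensor already lives on a
// backend of the same type, or when source and destination are both CPU
// variants that agree on element size and (for NC4HW4 layouts) pack size.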
bool WrapExecution::needWrap(const Tensor* input, Backend* curBackend) {
    auto curType = curBackend ? curBackend->type() : MNN_FORWARD_CPU;
    if (curType == MNN_FORWARD_NN) {
        return false;
    }
    auto des = TensorUtils::getDescribe(input);
    auto bn = des->backend;
    MNNForwardType type = MNN_FORWARD_CPU;
    int pack = 4;
    int bytes = 4;
    if (nullptr != bn) {
        type = bn->type();
        if (type == MNN_FORWARD_CPU_EXTENSION) {
            auto core = static_cast<CPUBackend*>(bn)->functions();
            pack = core->pack;
            bytes = core->bytes;
        }
    }
    if (type == curType) {
        return false;
    }
    bool srcCpu = (type == MNN_FORWARD_CPU_EXTENSION || type == MNN_FORWARD_CPU);
    bool dstCpu = ((curType == MNN_FORWARD_CPU_EXTENSION) || (curType == MNN_FORWARD_CPU));
    if (srcCpu && dstCpu) {
        int curBytes = 4, curPack = 4;
        if (curBackend) {
            auto dstCore = static_cast<CPUBackend*>(curBackend)->functions();
            curBytes = dstCore->bytes;
            curPack = dstCore->pack;
        }
        if (curBytes == bytes) {
            if (curPack == pack || des->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
                return false;
            }
        }
    }
    return true;
}

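// Wrap an execution that runs on another backend. `isStatic` marks a graph
// whose constant inputs can be copied once at resize time rather than on
// every call to onExecute.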
WrapExecution::WrapExecution(Backend* CPUBackend, std::shared_ptr<Execution> execution, bool isStatic)
    : Execution(execution->backend()), mCPUBackend(CPUBackend), mExecution(execution) {
    mValid  = execution->valid();
    mStatic = isStatic;
}

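// Return a tensor that the destination backend can read, creating and
// caching intermediate copies as needed. Three cases: CPU -> XPU and
// XPU -> CPU each need one wrap tensor; XPU -> XPU' routes through a
// CPU-side mid tensor plus a wrap tensor on the destination.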
Tensor* WrapExecution::_getCopyTensor(Tensor* inputTensor) {
    auto dstBackend = mExecution->backend();
    auto inputDes   = TensorUtils::getDescribe(inputTensor);
    auto srcBackend = inputDes->backend;
    if (nullptr == srcBackend) {
        srcBackend = mCPUBackend;
    }
    // CPU -> CPU or XPU -> XPU
    //if (srcBackend == dstBackend) {
    if (srcBackend->type() == dstBackend->type()) {
        return inputTensor;
    }
    auto iter = mInputMaps.find(inputTensor);
    if (iter != mInputMaps.end()) {
        return std::get<2>(iter->second).get();
    }
    // CPU -> XPU
    if (srcBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(dstBackend, dstBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU
    if (dstBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU -> XPU'
    std::shared_ptr<Tensor> midTensor(new Tensor);
    std::shared_ptr<Tensor> wrapTensor(new Tensor);
    TensorUtils::copyShape(inputTensor, midTensor.get(), true);
    TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
    TensorUtils::adjustTensorForCompability(wrapTensor.get());
    TensorUtils::adjustTensorForCompability(midTensor.get());
    TensorUtils::getDescribe(midTensor.get())->usage = TensorUtils::getDescribe(inputTensor)->usage;
    TensorUtils::getDescribe(midTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
    midTensor->buffer().type                         = inputTensor->buffer().type;
    wrapTensor->buffer().type                        = inputTensor->buffer().type;
    mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, midTensor)));
    mInputMaps.insert(std::make_pair(midTensor.get(), std::make_tuple(dstBackend, dstBackend, wrapTensor)));
    return wrapTensor.get();
}

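// Build the wrapped input list (redirecting raster regions for virtual
// inputs), allocate the copy buffers, resize the inner execution, then
// release the temporaries. In static mode, constant inputs are copied
// once here (DYNAMIC_SEPERATE); all other inputs are copied per run in
// onExecute.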
ErrorCode WrapExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    mWrapInputTensors.resize(inputs.size());
    mInputMaps.clear();

    auto dstBackend = mExecution->backend();
    for (int i = 0; i < inputs.size(); ++i) {
        auto inputTensor = inputs[i];
        auto des         = TensorUtils::getDescribe(inputTensor);
        if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            MNN_ASSERT(inputs.size() == 1);
            mWrapForRaster.reset(new Tensor);
            TensorUtils::copyShape(inputTensor, mWrapForRaster.get(), true);
            mWrapForRaster->buffer().type = inputTensor->buffer().type;
            auto wrapDes                  = TensorUtils::getDescribe(mWrapForRaster.get());
            wrapDes->memoryType           = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            wrapDes->regions              = des->regions;
            for (auto& r : wrapDes->regions) {
                r.origin = _getCopyTensor(r.origin);
            }
            mWrapInputTensors[i] = mWrapForRaster.get();
        } else {
            mWrapInputTensors[i] = _getCopyTensor(inputTensor);
        }
    }

    for (int i = 0; i < outputs.size(); ++i) {
        MNN_ASSERT(TensorUtils::getDescribe(outputs[i])->backend == dstBackend);
    }
    bool memoryAllocSuccess = true;
    // acquire memory, copy const tensors
    for (auto& iter : mInputMaps) {
        auto backend   = std::get<0>(iter.second);
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();

        if (TensorUtils::getDescribe(src)->usage == TensorUsage::CONSTANT && mStatic) {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC_SEPERATE);
            if (memoryAllocSuccess) {
                converter->onCopyBuffer(src, dst);
                TensorUtils::getDescribe(dst)->usage = TensorUtils::getDescribe(src)->usage;
            }
        } else {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC);
        }
    }
    if (!memoryAllocSuccess) {
        return OUT_OF_MEMORY;
    }

    // do resize
    auto result = mExecution->onResize(mWrapInputTensors, outputs);

    // release memory
    for (auto& iter : mInputMaps) {
        auto backend = std::get<0>(iter.second);
        auto dst     = std::get<2>(iter.second).get();

        if (TensorUtils::getDescribe(dst)->usage == TensorUsage::CONSTANT && mStatic) {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC_SEPERATE);
        } else {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC);
        }
    }
    return result;
}

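// Copy every non-constant (or, in non-static mode, every) input to its
// wrap tensor, then run the inner execution on the wrapped inputs.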
ErrorCode WrapExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(mWrapInputTensors.size() == inputs.size());

    // copy variant tensors
    for (auto& iter : mInputMaps) {
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(src)->usage != TensorUsage::CONSTANT || (!mStatic)) {
            converter->onCopyBuffer(src, dst);
        }
    }
    auto code = mExecution->onExecute(mWrapInputTensors, outputs);
    return code;
}

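// Create the real execution with the inputs temporarily retyped to `runT`
// (e.g. float <-> int8 for quantized graphs), then restore the original
// tensor types so the surrounding graph is left unchanged.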
CastWrapExecution::CastWrapExecution(const CPUBackend::Creator* creator, const Op* op, Backend* backend,
                                     const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, halide_type_t runT)
                                   : Execution(backend), runType(runT), mCreator(creator), mType(op->type()), mInputs(inputs) {
    std::vector<int> types(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
        types[i] = TensorUtils::HaildeTypeToDataType(inputs[i]->getType());
        inputs[i]->setType(TensorUtils::HaildeTypeToDataType(runType));
    }
    mExecution.reset(mCreator->onCreate(inputs, outputs, op, backend));
    for (int i = 0; i < inputs.size(); i++) {
        inputs[i]->setType(types[i]);
    }
}
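// For every input whose type differs from `runType` (and whose content the
// op actually reads), allocate a cast tensor, reuse it via the backend's
// cast-tensor cache, and record the quantization scale replicated `pack`
// times for SIMD. Raster ops get a rebuilt virtual input whose regions
// point at the cast tensors.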
ErrorCode CastWrapExecution::onResize(const std::vector<Tensor*>& inputs,
                                      const std::vector<Tensor*>& outputs) {
    for (auto output : outputs) {
        output->setType(TensorUtils::HaildeTypeToDataType(runType));
    }
    mWrapInputs.clear();
    mCasts.clear();
    mScales.clear();
    auto& cachedCastTensor = static_cast<CPUBackend*>(backend())->getCachedCastTensor();
    std::vector<Tensor*> realInput;
    if (mType == OpType_Raster) {
        for (const auto& r : TensorUtils::getDescribe(inputs[0])->regions) {
            realInput.push_back(r.origin);
        }
    } else {
        realInput = inputs;
    }
    for (int i = 0; i < realInput.size(); i++) {
        auto input = realInput[i];
        if (input->getType() == runType || !OpCommonUtils::opNeedContent(mType, i) || input->getType() == halide_type_of<int>()) {
            mWrapInputs.push_back(input);
            continue;
        }
        if (cachedCastTensor.find(input) != cachedCastTensor.end()) {
            mWrapInputs.push_back(const_cast<Tensor*>(cachedCastTensor[input]));
            continue;
        }
        std::unique_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(input, wrapTensor.get(), true);
        TensorUtils::setLinearLayout(wrapTensor.get());
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(input)->quantAttr;
        wrapTensor->buffer().type = runType;
        bool memoryAllocSuccess = backend()->onAcquireBuffer(wrapTensor.get(), Backend::DYNAMIC);
        if (!memoryAllocSuccess) {
            return OUT_OF_MEMORY;
        }
        mWrapInputs.push_back(wrapTensor.get());
        mCasts.insert(std::make_pair(input, wrapTensor.get()));
        cachedCastTensor.insert(std::make_pair(input, wrapTensor.get()));
        mWrapInputTensor.emplace_back(std::move(wrapTensor));
        auto pack = static_cast<CPUBackend*>(backend())->functions()->pack;
        auto& quantAttr = TensorUtils::getDescribe(input)->quantAttr;
        float scale = runType == halide_type_of<float>() ? quantAttr->scale : 1.0f / quantAttr->scale;
        // replicate the scale `pack` times for SIMD (e.g. SSE) compute
        mScales[input] = std::vector<float>(pack, scale);
    }
    ErrorCode res = NO_ERROR;
    if (mType == OpType_Raster) {
        mRasterInputTensor.reset(new Tensor(inputs[0], inputs[0]->getDimensionType(), false));
        mRasterInput = mRasterInputTensor.get();
        TensorUtils::getDescribe(mRasterInput)->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        TensorUtils::getDescribe(mRasterInput)->regions.resize(realInput.size());
        for (int i = 0; i < realInput.size(); i++) {
            TensorUtils::getDescribe(mRasterInput)->regions[i] = TensorUtils::getDescribe(inputs[0])->regions[i];
            TensorUtils::getDescribe(mRasterInput)->regions[i].origin = mWrapInputs[i];
        }
        res = mExecution->onResize({mRasterInput}, outputs);
    } else {
        res = mExecution->onResize(mWrapInputs, outputs);
    }
    for (auto& iter : mCasts) {
        if (TensorUtils::getDescribe(iter.first)->useCount <= 1) {
            backend()->onReleaseBuffer(iter.second, Backend::DYNAMIC);
        }
    }
    return res;
}
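// Cast each recorded input on the CPU backend, then run the wrapped
// execution (through the rebuilt virtual input when the op is a Raster).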
ErrorCode CastWrapExecution::onExecute(const std::vector<Tensor*>& inputs,
                                       const std::vector<Tensor*>& outputs) {
    for (const auto& iter : mCasts) {
        auto input = iter.first;
        auto output = iter.second;
        auto& quantAttr = TensorUtils::getDescribe(input)->quantAttr;
        MNN_ASSERT(quantAttr != nullptr);
        auto cpuBackend = ((CPUBackend*)backend());
        CPUCastCreator::cast(input, output, cpuBackend);
    }
    if (mType == OpType_Raster) {
        return mExecution->onExecute({ mRasterInput }, outputs);
    } else {
        return mExecution->onExecute(mWrapInputs, outputs);
    }
}
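// Clone the inner execution onto `bn` and hand ownership of the clone to a
// new CastWrapExecution (via a constructor overload not defined in this file).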
bool CastWrapExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (dst == nullptr || bn == nullptr) {
        return true;
    }
    Execution* exe;
    mExecution->onClone(bn, op, &exe);
    *dst = new CastWrapExecution(bn, runType, op, exe);
    return true;
}

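// Debug wrapper: takes ownership of `exe` and verifies that no float input
// or output contains NaN or Inf around each run.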
CheckNANExecution::CheckNANExecution(Execution* exe) : Execution(exe->backend()) {
    mExecution = exe;
    mValid = exe->valid();
}

CheckNANExecution::~CheckNANExecution() {
    delete mExecution;
}

ErrorCode CheckNANExecution::onResize(const std::vector<Tensor*>& inputs,
                                      const std::vector<Tensor*>& outputs) {
    return mExecution->onResize(inputs, outputs);
}

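// Scan float inputs (skipping virtual raster tensors), run the inner
// execution, then scan float outputs; return INVALID_VALUE on the first
// NaN or Inf found.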
ErrorCode CheckNANExecution::onExecute(const std::vector<Tensor*>& inputs,
                                       const std::vector<Tensor*>& outputs) {
    for (auto tensor : inputs) {
        if (halide_type_float != tensor->getType().code) {
            continue;
        }
        if (TensorUtils::getDescribe(tensor)->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            continue;
        }
#define MNN_IS_INF(x) (fabs(x) == INFINITY)
#define MNN_IS_NAN(x) ((x) != (x))
        auto size = tensor->elementSize();
        auto ptr  = tensor->host<float>();
        for (int i = 0; i < size; ++i) {
            auto value = ptr[i];
            if (MNN_IS_INF(value) || MNN_IS_NAN(value)) {
                return INVALID_VALUE;
            }
        }
    }
    auto code = mExecution->onExecute(inputs, outputs);
    if (NO_ERROR != code) {
        return code;
    }
    for (auto tensor : outputs) {
        if (halide_type_float != tensor->getType().code) {
            continue;
        }
        auto size = tensor->elementSize();
        auto ptr  = tensor->host<float>();
        for (int i = 0; i < size; ++i) {
            auto value = ptr[i];
            if (MNN_IS_INF(value) || MNN_IS_NAN(value)) {
                return INVALID_VALUE;
            }
        }
    }
    return NO_ERROR;
}

} // namespace MNN