MNN/source/backend/cpu/CPUBackend.cpp

726 lines
28 KiB
C++
Raw Normal View History

2019-04-17 10:49:11 +08:00
//
// CPUBackend.cpp
// MNN
//
// Created by MNN on 2018/07/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
2019-12-27 22:16:57 +08:00
#include "backend/cpu/CPUBackend.hpp"
#include <cmath>
2019-04-17 10:49:11 +08:00
#include <mutex>
2019-12-27 22:16:57 +08:00
#include "core/BufferAllocator.hpp"
2021-04-08 15:34:23 +08:00
#include "CPUTensorConvert.hpp"
2020-07-04 01:21:30 +08:00
#include "compute/CommonOptFunction.h"
2021-04-08 15:34:23 +08:00
#include "core/TensorUtils.hpp"
#include "ThreadPool.hpp"
#include "core/Concurrency.h"
#include "compute/Int8FunctionsOpt.h"
#include "CPUCast.hpp"
#include "core/OpCommonUtils.hpp"
2019-04-17 10:49:11 +08:00
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
2019-12-27 22:16:57 +08:00
#include "backend/cpu/CPURuntime.hpp"
#include "core/Macro.h"
#ifdef MNN_USE_ARMV82
2020-11-05 16:41:56 +08:00
#include "backend/arm82/Arm82Backend.hpp"
#endif
2019-04-17 10:49:11 +08:00
#define MAX_THREAD_NUMBER 32
2020-12-15 14:12:35 +08:00
#define LARGE_MEMORY 1024 * 1024 * 500
2021-04-08 15:34:23 +08:00
#ifdef MNN_SUPPORT_BF16
#include "bf16/BF16Backend.hpp"
#endif
2019-04-17 10:49:11 +08:00
#ifdef MNN_USE_SSE
#include "x86_x64/AVX2Backend.hpp"
#endif
#define MNN_CPU_CHECK_NAN 1
#define MNN_CPU_USE_DEFAULT_BACKEND 4
2019-04-17 10:49:11 +08:00
namespace MNN {
2019-05-09 19:39:33 +08:00
void registerCPUOps();
2019-04-17 10:49:11 +08:00
2020-11-05 16:41:56 +08:00
CPURuntime::CPURuntime(const Backend::Info& info) {
    // Pool shared by every CPUBackend created from this runtime (see CPUBackend ctor).
    mStaticAllocator.reset(new BufferAllocator(BufferAllocator::Allocator::createDefault()));

    // Requested thread count, clamped into [1, MAX_THREAD_NUMBER].
    mThreadNumber = info.numThread;
    mThreadNumber = std::min(std::max(1, mThreadNumber), MAX_THREAD_NUMBER);

    // Defaults, optionally overridden by the user-supplied BackendConfig below.
    mPower     = BackendConfig::Power_Normal;
    mMemory    = BackendConfig::Memory_Normal;
    mPrecision = BackendConfig::Precision_Normal;
    mFlops     = MNNGetCPUFlops(mThreadNumber);
    const auto* userConfig = info.user;
    if (nullptr != userConfig) {
        mPrecision = userConfig->precision;
        mPower     = userConfig->power;
        mMemory    = userConfig->memory;
        mFlags     = userConfig->flags;
    }

#ifdef _OPENMP
    // Pin OpenMP threads to little / performance cores depending on the power mode.
    if (BackendConfig::Power_Low == mPower) {
        MNNSetCPUThreadsMode(MNN_CPU_MODE_LITTLE);
    } else if (BackendConfig::Power_High == mPower) {
        MNNSetCPUThreadsMode(MNN_CPU_MODE_POWER_FRI);
    }
#endif

#ifdef MNN_USE_THREAD_POOL
    // The pool may grant fewer threads than requested; only multi-threaded runtimes take a work slot.
    mThreadNumber = ThreadPool::init(mThreadNumber);
    mTaskIndex    = (mThreadNumber > 1) ? ThreadPool::acquireWorkIndex() : -1;
    // High-power runtimes keep the pool active for their whole lifetime (undone in the destructor).
    if (mTaskIndex >= 0 && BackendConfig::Power_High == mPower) {
        ThreadPool::active();
    }
#endif
}
2020-11-05 16:41:56 +08:00
CPURuntime::~CPURuntime() {
#ifdef MNN_USE_THREAD_POOL
    // Undo the ThreadPool::active() the constructor performed for high-power runtimes.
    const bool keptPoolActive = mTaskIndex >= 0 && BackendConfig::Power_High == mPower;
    if (keptPoolActive) {
        ThreadPool::deactive();
    }
    // Always hand the work slot back (mTaskIndex is -1 when none was acquired).
    ThreadPool::releaseWorkIndex(mTaskIndex);
#endif
}
2020-11-05 16:41:56 +08:00
float CPURuntime::onGetMemoryInMB() {
    // Only the runtime-owned static allocator is reported here; each CPUBackend
    // additionally owns its own dynamic allocator.
    constexpr float kBytesPerMB = 1024.0f * 1024.0f;
    return mStaticAllocator->totalSize() / kBytesPerMB;
}
2021-04-08 15:34:23 +08:00
Backend* CPURuntime::onCreate(const BackendConfig* config) const {
    // A per-session config, when given, overrides the runtime-level precision and flags.
    const bool hasConfig = nullptr != config;
    auto precision       = hasConfig ? config->precision : mPrecision;
    size_t flags         = hasConfig ? config->flags : mFlags;

#ifdef MNN_USE_ARMV82
    // fp16 backend when the CPU supports fp16 arithmetic and low precision is requested.
    const bool fp16Capable = MNNGetCoreFunctions()->supportFp16arith;
    if (fp16Capable && BackendConfig::Precision_Low == precision) {
        return new Arm82Backend(this);
    }
#endif
#ifdef MNN_SUPPORT_BF16
    if (BackendConfig::Precision_Low == precision) {
        return new BF16Backend(this);
    }
#endif
    // MNN_CPU_USE_DEFAULT_BACKEND forces the plain CPU backend and clears the flags.
    if (MNN_CPU_USE_DEFAULT_BACKEND == flags) {
        return new CPUBackend(this, precision, MNN_FORWARD_CPU, 0);
    }
#ifdef MNN_USE_SSE
    if (AVX2Backend::isValid()) {
        return new AVX2Backend(this, flags);
    }
#endif
    return new CPUBackend(this, precision, MNN_FORWARD_CPU, flags);
}
void CPURuntime::onGabageCollect(int level) {
    // `level` is currently ignored: every call releases the static pool with release(false).
    (void)level;
    mStaticAllocator->release(false);
}
// OpType -> Creator registry consulted by CPUBackend::onCreate / onMeasure; filled via addCreator().
std::map<OpType, CPUBackend::Creator*>* CPUBackend::gCreator = nullptr;
void CPUBackend::initCreatorMap() {
    // Heap-allocated registry; nothing in this file deletes it.
    gCreator = new std::map<OpType, CPUBackend::Creator*>();
}
bool CPUBackend::addCreator(OpType t, Creator* c) {
    // map::insert keeps an existing entry untouched and reports whether it inserted.
    const bool inserted = gCreator->insert(std::make_pair(t, c)).second;
    if (!inserted) {
        MNN_PRINT("Error: %d type has be added\n", t);
        return false;
    }
    return true;
}
CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, MNNForwardType type, size_t flags) : Backend(type) {
    mRuntime       = runtime;
    mPrecisionMode = precision;
    // NaN/Inf checking wrapper (see makePostWrapExectuion) is enabled by this exact flag value.
    mCheckNAN = (MNN_CPU_CHECK_NAN == flags);

    // Static buffers come directly from the runtime's shared allocator; dynamic buffers use
    // an allocator created with createRecurse() on top of that same static allocator.
    mStaticAllocator = runtime->mStaticAllocator;
    std::shared_ptr<BufferAllocator::Allocator> recurseAllocator(
        BufferAllocator::Allocator::createRecurse(runtime->mStaticAllocator.get()));
    mDynamicAllocator.reset(new BufferAllocator(recurseAllocator));

    mCoreFunctions     = MNNGetCoreFunctions();
    mInt8CoreFunctions = MNNGetInt8CoreFunctions();
}
// Nothing to release explicitly: the allocators are held by smart pointers.
CPUBackend::~CPUBackend() = default;
void CPUBackend::onExecuteBegin() const {
#ifdef MNN_USE_THREAD_POOL
    // High-power runtimes keep the pool active for their whole lifetime (CPURuntime ctor/dtor),
    // so only the other power modes switch it on around each execution.
    const bool togglesPool = mRuntime->mTaskIndex >= 0 && BackendConfig::Power_High != mRuntime->mPower;
    if (togglesPool) {
        ThreadPool::active();
    }
#elif defined(_OPENMP)
    // Fixed-size OpenMP team sized to this backend's thread count.
    omp_set_dynamic(0);
    omp_set_num_threads(threadNumber());
#endif
}
void CPUBackend::onExecuteEnd() const {
#ifdef MNN_USE_THREAD_POOL
    // Mirror of onExecuteBegin: nothing to do without a work slot or while the runtime
    // keeps the pool active (Power_High).
    if (mRuntime->mTaskIndex < 0 || BackendConfig::Power_High == mRuntime->mPower) {
        return;
    }
    ThreadPool::deactive();
#endif
}
2019-04-17 10:49:11 +08:00
2020-12-15 14:12:35 +08:00
bool CPUBackend::allocBuffer(int size, Tensor* dest, StorageType storageType) {
2019-04-17 10:49:11 +08:00
// MNN_PRINT("Acquire size = %d\n", size);
if (size <= 0) {
2021-04-08 15:34:23 +08:00
MNN_PRINT("Acquire buffer size = %d\n", size);
2019-04-17 10:49:11 +08:00
MNN_ASSERT(false);
return false;
}
2021-04-08 15:34:23 +08:00
// if (size > LARGE_MEMORY) {
// MNN_PRINT("Size larger than 500 M :%d\n", size);
// }
2020-12-15 14:12:35 +08:00
auto& buffer = dest->buffer();
auto des = TensorUtils::getDescribe(dest);
std::pair<void*, int> points;
2019-04-17 10:49:11 +08:00
switch (storageType) {
case STATIC: {
2020-12-15 14:12:35 +08:00
points = mStaticAllocator->alloc(size, false);
2019-04-17 10:49:11 +08:00
break;
}
case DYNAMIC: {
2020-12-15 14:12:35 +08:00
points = mDynamicAllocator->alloc(size, false);
2019-04-17 10:49:11 +08:00
break;
}
case DYNAMIC_SEPERATE: {
2020-12-15 14:12:35 +08:00
points = mDynamicAllocator->alloc(size, true);
2019-04-17 10:49:11 +08:00
break;
}
default:
2020-11-05 16:41:56 +08:00
MNN_ASSERT(false);
2019-04-17 10:49:11 +08:00
break;
}
2020-12-15 14:12:35 +08:00
if (nullptr == points.first) {
2019-04-17 10:49:11 +08:00
MNN_ERROR("Alloc buffer error for cpu backend\n");
return false;
}
2020-12-15 14:12:35 +08:00
buffer.host = (uint8_t*)points.first + points.second;
des->extra.offset = points.second;
2019-04-17 10:49:11 +08:00
if (buffer.type.code == halide_type_handle) {
2020-12-15 14:12:35 +08:00
// For handle we needn't recycle the buffer, use extra as hanleFreeFunction
2019-04-17 10:49:11 +08:00
::memset(buffer.host, 0, size);
2020-12-15 14:12:35 +08:00
des->extra.handleFreeFunction = (decltype(des->extra.handleFreeFunction))free;
2019-04-17 10:49:11 +08:00
}
return true;
}
2020-11-05 16:41:56 +08:00
bool CPUBackend::onAcquireBuffer(const MNN::Tensor* nativeTensorConst, StorageType storageType) {
    if (nullptr == nativeTensorConst) {
        return false;
    }
    // allocBuffer writes the host pointer into the tensor, hence the const_cast.
    auto writableTensor = const_cast<Tensor*>(nativeTensorConst);
    return allocBuffer(writableTensor->size(), writableTensor, storageType);
}
2019-04-17 10:49:11 +08:00
bool CPUBackend::onReleaseBuffer(const MNN::Tensor* nativeTensor, StorageType storageType) {
    // Separately allocated dynamic memory is not handed back to an allocator here.
    if (DYNAMIC_SEPERATE == storageType) {
        return true;
    }
    if (nullptr == nativeTensor || nullptr == nativeTensor->buffer().host) {
        return false;
    }
    // Rebuild the (base pointer, offset) pair that allocBuffer received from the allocator.
    auto describe = TensorUtils::getDescribe(nativeTensor);
    std::pair<void*, int> chunk;
    chunk.first  = (uint8_t*)nativeTensor->buffer().host - describe->extra.offset;
    chunk.second = describe->extra.offset;
    if (STATIC == storageType) {
        mStaticAllocator->free(chunk);
    } else {
        mDynamicAllocator->free(chunk);
    }
    return true;
}
2020-11-05 16:41:56 +08:00
std::pair<float, bool> CPUBackend::onMeasure(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                             const MNN::Op* op) {
    // Cost estimation is not implemented: the result is always (0, false). Unsupported op
    // types are still logged so the caller can see why.
    const bool registered = gCreator->find(op->type()) != gCreator->end();
    if (!registered) {
        MNN_PRINT("Don't support type %s, %s\n", MNN::EnumNameOpType(op->type()), op->name()->c_str());
    }
    // FIXME: Compute in future
    return std::make_pair(0.0f, false);
}
/**
 * Choose the data type `op` executes in: the quantized type `qtype` or the real type `rtype`.
 * (Quantized Pooling/Eltwise paths are currently disabled and fall through to `rtype`.)
 */
halide_type_t CPUBackend::getRunType(const Op* op, halide_type_t qtype, halide_type_t rtype) {
    switch (op->type()) {
        case OpType_Convolution:
        case OpType_ConvolutionDepthwise: {
            // Convolutions carrying a `weight` table run in the real type; otherwise quantized.
            auto conv = op->main_as_Convolution2D();
            const bool hasWeight = conv && conv->weight() != nullptr;
            return hasWeight ? rtype : qtype;
        }
        case OpType_ConvInt8:
        case OpType_DepthwiseConvInt8:
        case OpType_Raster:
            return qtype;
        case OpType_ReLU: {
            // Only ReLU without a slope (or without params) supports quantized execution.
            auto relu = op->main_as_Relu();
            const bool plainRelu = (nullptr == relu) || relu->slope() == 0.f;
            return plainRelu ? qtype : rtype;
        }
        default:
            return rtype;
    }
}
/**
 * Map an op type to its dedicated int8 variant when `dataType` is int8.
 * Only convolutions have such variants today (Pooling/Eltwise int8 mappings are disabled).
 */
OpType CPUBackend::getRealOpType(OpType opType, halide_type_t dataType) {
    const bool isInt8 = !(dataType != halide_type_of<int8_t>());
    if (isInt8 && OpType_Convolution == opType) {
        return OpType_ConvInt8;
    }
    if (isInt8 && OpType_ConvolutionDepthwise == opType) {
        return OpType_DepthwiseConvInt8;
    }
    return opType;
}
/**
 * Number of elements backing `tensor`; for NC4HW4 tensors dimension 1 is rounded up to a
 * multiple of the core function table's `pack`.
 */
int CPUBackend::getTensorSize(const Tensor* tensor) const {
    const auto pack          = mCoreFunctions->pack;
    const bool packedChannel = MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(tensor)->dimensionFormat;
    int total = 1;
    for (int axis = 0; axis < tensor->dimensions(); ++axis) {
        int extent = tensor->length(axis);
        if (packedChannel && axis == 1) {
            extent = UP_DIV(extent, pack) * pack;
        }
        total *= extent;
    }
    return total;
}
2019-04-17 10:49:11 +08:00
/**
 * Create the CPU execution for `op`.
 *
 * When the inputs carry quantization info (OpCommonUtils::getQuantInfo), the op may be
 * redirected to its int8 variant (getRealOpType). If any input/output type then differs
 * from the chosen run type, the real execution is wrapped in a CastWrapExecution that
 * converts mismatched inputs into runType copies before running it.
 * The result is finally passed through makePostWrapExectuion.
 *
 * @return the execution, or nullptr for BatchNorm, unregistered op types, or creator failure.
 */
Execution* CPUBackend::onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op) {
    /**
     BatchNorm it will be converted to scale
     for model convert, don't print error log
     */
    if (op->type() == OpType_BatchNorm) {
        return nullptr;
    }
    // get QuantType and RunType, default is float
    halide_type_t quantType = halide_type_of<float>();
    auto isQuant = OpCommonUtils::getQuantInfo(inputs);
    if (isQuant.first) {
        // If the output has no quant attribute, use the output's own type.
        // Emptiness is checked first so outputs[0] is never read from an empty vector.
        if (!outputs.empty() && TensorUtils::getDescribe(outputs[0])->quantAttr == nullptr) {
            quantType = outputs[0]->getType();
        } else {
            quantType = TensorUtils::DataTypeToHalideType(isQuant.second);
        }
    }
    auto originType = outputs.empty() ? halide_type_of<float>() : outputs[0]->getType();
    auto runType = getRunType(op, quantType, originType);
    // TODO: rm this convert when merge diff datatyoe of op
    auto opType = op->type();
    if (isQuant.first) {
        opType = getRealOpType(opType, runType);
    }
    auto map = gCreator;
    auto iter = map->find(opType);
    if (iter == map->end()) {
        MNN_PRINT("Don't support type [%s], %s\n", MNN::EnumNameOpType(op->type()), op->name()->c_str());
        return nullptr;
    }
    Execution* exe = nullptr;
    if (isQuant.first) {
        bool needCast = false;
        // judge is it need CastWrap
        if (OpType_Raster == opType) {
            inputs[0]->setType(TensorUtils::HaildeTypeToDataType(runType));
            for (const auto& r : TensorUtils::getDescribe(inputs[0])->regions) {
                needCast |= (r.origin->getType() != runType);
            }
        } else {
            for (int i = 0; i < static_cast<int>(inputs.size()); i++) {
                if (OpCommonUtils::opNeedContent(opType, i) && inputs[i]->getType() != halide_type_of<int>()) {
                    needCast |= (inputs[i]->getType() != runType);
                }
            }
        }
        // set output Tensor Type
        auto outputType = TensorUtils::HaildeTypeToDataType(runType);
        for (auto output : outputs) {
            if (output->getType() != runType) {
                output->setType(outputType);
                needCast = true;
            }
        }
        if (needCast) {
            // Wraps the real execution. onResize prepares runType copies of mismatched inputs
            // (reusing the backend's cast cache when possible); onExecute fills them using the
            // input's quantAttr and then runs the wrapped execution on the copies.
            class CastWrapExecution : public Execution {
            public:
                // Used by onClone: wraps an already created execution.
                CastWrapExecution(Backend* backend, halide_type_t runT, const Op* op, Execution* exe)
                    : Execution(backend), mType(op->type()), runType(runT), mExecution(exe) {}
                // Creates the real execution through `creator` while the inputs temporarily
                // report runT, then restores the inputs' original types.
                CastWrapExecution(const CPUBackend::Creator* creator, const Op* op, Backend* backend,
                                  const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, halide_type_t runT)
                    : Execution(backend), mType(op->type()), mCreator(creator), runType(runT) {
                    std::vector<int> types(inputs.size());
                    for (size_t i = 0; i < inputs.size(); i++) {
                        types[i] = TensorUtils::HaildeTypeToDataType(inputs[i]->getType());
                        inputs[i]->setType(TensorUtils::HaildeTypeToDataType(runType));
                    }
                    mExecution.reset(mCreator->onCreate(inputs, outputs, op, backend));
                    for (size_t i = 0; i < inputs.size(); i++) {
                        inputs[i]->setType(types[i]);
                    }
                }
                // False when the wrapped creator failed to produce an execution.
                bool hasExecution() const {
                    return nullptr != mExecution;
                }
                virtual ErrorCode onResize(const std::vector<Tensor*>& inputs,
                                           const std::vector<Tensor*>& outputs) override {
                    for (auto output : outputs) {
                        output->setType(TensorUtils::HaildeTypeToDataType(runType));
                    }
                    mWrapInputs.clear();
                    mCasts.clear();
                    mScales.clear();
                    auto& cachedCastTensor = static_cast<CPUBackend*>(backend())->getCachedCastTensor();
                    // For Raster, the tensors that need casting are the region origins.
                    std::vector<Tensor*> realInput;
                    if (mType == OpType_Raster) {
                        for (const auto& r : TensorUtils::getDescribe(inputs[0])->regions) {
                            realInput.push_back(r.origin);
                        }
                    } else {
                        realInput = inputs;
                    }
                    for (int i = 0; i < static_cast<int>(realInput.size()); i++) {
                        auto input = realInput[i];
                        if (input->getType() == runType || !OpCommonUtils::opNeedContent(mType, i) || input->getType() == halide_type_of<int>()) {
                            mWrapInputs.push_back(input);
                            continue;
                        }
                        auto cached = cachedCastTensor.find(input);
                        if (cached != cachedCastTensor.end()) {
                            mWrapInputs.push_back(const_cast<Tensor*>(cached->second.get()));
                            continue;
                        }
                        std::unique_ptr<Tensor> wrapTensor(new Tensor);
                        TensorUtils::copyShape(input, wrapTensor.get(), true);
                        TensorUtils::setLinearLayout(wrapTensor.get());
                        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(input)->quantAttr;
                        wrapTensor->buffer().type = runType;
                        bool memoryAllocSuccess = backend()->onAcquireBuffer(wrapTensor.get(), Backend::DYNAMIC);
                        if (!memoryAllocSuccess) {
                            // `return {}` would have reported NO_ERROR for a failed allocation.
                            return OUT_OF_MEMORY;
                        }
                        mWrapInputs.push_back(wrapTensor.get());
                        mCasts.insert(std::make_pair(input, wrapTensor.get()));
                        cachedCastTensor.insert(std::make_pair(input, std::move(wrapTensor)));
                        auto& quantAttr = TensorUtils::getDescribe(input)->quantAttr;
                        float scale = runType == halide_type_of<float>() ? quantAttr->scale : 1 / quantAttr->scale;
                        // set 4xscale for SSE compute
                        mScales[input] = std::vector<float>(4, scale);
                    }
                    ErrorCode res = NO_ERROR;
                    if (mType == OpType_Raster) {
                        // Same regions as the original raster input, but reading from the cast copies.
                        mRasterInputTensor.reset(new Tensor(inputs[0], inputs[0]->getDimensionType(), false));
                        mRasterInput = mRasterInputTensor.get();
                        TensorUtils::getDescribe(mRasterInput)->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                        TensorUtils::getDescribe(mRasterInput)->regions.resize(realInput.size());
                        for (size_t i = 0; i < realInput.size(); i++) {
                            TensorUtils::getDescribe(mRasterInput)->regions[i] = TensorUtils::getDescribe(inputs[0])->regions[i];
                            TensorUtils::getDescribe(mRasterInput)->regions[i].origin = mWrapInputs[i];
                        }
                        res = mExecution->onResize({mRasterInput}, outputs);
                    } else {
                        res = mExecution->onResize(mWrapInputs, outputs);
                    }
                    for (auto& iter : mCasts) {
                        if (TensorUtils::getDescribe(iter.first)->useCount <= 1) {
                            backend()->onReleaseBuffer(iter.second, Backend::DYNAMIC);
                        }
                    }
                    return res;
                }
                virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs,
                                            const std::vector<Tensor*>& outputs) override {
                    for (const auto& iter : mCasts) {
                        auto input = iter.first;
                        auto output = iter.second;
                        auto& quantAttr = TensorUtils::getDescribe(input)->quantAttr;
                        MNN_ASSERT(quantAttr != nullptr);
                        auto numberThread = ((CPUBackend*)backend())->threadNumber();
                        if (numberThread == 1) {
                            CPUCastCreator::cast(input, output);
                            continue;
                        }
                        // Groups of 16 elements are split across threads; the last thread also
                        // takes the groups left over by the integer division. The remaining
                        // elements (fewer than 16) are handled by the scalar tail loop.
                        int size = input->elementSize();
                        int sizeQuad = size / 16;
                        int tailStart = sizeQuad * 16;
                        int sizeDivide = sizeQuad / numberThread;
                        auto scale = mScales[input].data();
                        if (runType == halide_type_of<float>()) {
                            const auto inputDataPtr = input->host<int8_t>();
                            auto outputDataPtr = output->host<float>();
                            if (sizeQuad > 0) {
                                MNN_CONCURRENCY_BEGIN(tId, numberThread) {
                                    int number = sizeDivide;
                                    if (tId == numberThread - 1) {
                                        number = sizeQuad - tId * sizeDivide;
                                    }
                                    const auto srcChannelPtr = inputDataPtr + tId * sizeDivide * 16;
                                    auto dstChannlePtr = outputDataPtr + tId * sizeDivide * 16;
                                    if (number > 0) {
                                        // Count is in units of 4 elements: `number` groups of 16.
                                        MNNInt8ScaleToFloat(dstChannlePtr, srcChannelPtr, scale, number * 4, quantAttr->zero);
                                    }
                                }
                                MNN_CONCURRENCY_END();
                            }
                            for (int i = tailStart; i < size; i++) {
                                outputDataPtr[i] = (inputDataPtr[i] - quantAttr->zero) * scale[0];
                            }
                        } else {
                            const auto inputDataPtr = input->host<float>();
                            auto outputDataPtr = output->host<int8_t>();
                            if (sizeQuad > 0) {
                                MNN_CONCURRENCY_BEGIN(tId, numberThread) {
                                    int number = sizeDivide;
                                    if (tId == numberThread - 1) {
                                        number = sizeQuad - tId * sizeDivide;
                                    }
                                    const auto srcChannelPtr = inputDataPtr + tId * sizeDivide * 16;
                                    auto dstChannlePtr = outputDataPtr + tId * sizeDivide * 16;
                                    if (number > 0) {
                                        // Count is in units of 4 elements: `number` groups of 16.
                                        MNNFloat2Int8(srcChannelPtr, dstChannlePtr, number * 4, scale, quantAttr->min, quantAttr->max, quantAttr->zero);
                                    }
                                }
                                MNN_CONCURRENCY_END();
                            }
                            for (int i = tailStart; i < size; i++) {
                                float value = std::round(inputDataPtr[i] * scale[0] + quantAttr->zero);
                                outputDataPtr[i] = static_cast<int8_t>(std::min(std::max(value, quantAttr->min), quantAttr->max));
                            }
                        }
                    }
                    if (mType == OpType_Raster) {
                        return mExecution->onExecute({ mRasterInput }, outputs);
                    } else {
                        return mExecution->onExecute(mWrapInputs, outputs);
                    }
                }
                virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override {
                    if (dst == nullptr || bn == nullptr) {
                        return true;
                    }
                    // Never wrap an unset pointer: fail the clone if the inner execution cannot clone.
                    Execution* exe = nullptr;
                    if (!mExecution->onClone(bn, op, &exe) || nullptr == exe) {
                        return false;
                    }
                    *dst = new CastWrapExecution(bn, runType, op, exe);
                    return true;
                }
            private:
                MNN::OpType mType;
                const CPUBackend::Creator* mCreator = nullptr;
                halide_type_t runType;
                std::shared_ptr<Execution> mExecution;
                Tensor* mRasterInput = nullptr;
                std::vector<Tensor*> mWrapInputs;
                std::unique_ptr<Tensor> mRasterInputTensor;
                std::map<const Tensor*, const Tensor*> mCasts;
                std::map<const Tensor*, std::vector<float>> mScales;
            };
            std::unique_ptr<CastWrapExecution> wrap(
                new CastWrapExecution(iter->second, op, this, inputs, outputs, runType));
            if (!wrap->hasExecution()) {
                // The real creator failed; a wrapper around nothing would crash in onResize.
                return nullptr;
            }
            exe = wrap.release();
        }
    }
    if (exe == nullptr) {
        exe = iter->second->onCreate(inputs, outputs, op, this);
    }
    if (nullptr == exe) {
        return nullptr;
    }
    return makePostWrapExectuion(exe);
}
/**
 * Optionally wrap `execution` in a NaN/Inf checker (enabled when the backend was created with
 * flags == MNN_CPU_CHECK_NAN). The wrapper takes ownership of `execution`.
 *
 * The checker returns INVALID_VALUE if a float, non-virtual input contains NaN/Inf (before
 * running) or a float, non-virtual output does (after running). Tensors that are not float,
 * or are virtual, are skipped individually rather than ending the check; previously
 * `return NO_ERROR` on such an input also skipped running the wrapped op.
 */
Execution* CPUBackend::makePostWrapExectuion(Execution* execution) const {
    if (!mCheckNAN) {
        return execution;
    }
    class CheckNANExecution : public Execution {
    public:
        CheckNANExecution(Execution* exe) : Execution(exe->backend()), mExecution(exe) {
            mValid = exe->valid();
        }
        virtual ErrorCode onResize(const std::vector<Tensor*>& inputs,
                                   const std::vector<Tensor*>& outputs) override {
            return mExecution->onResize(inputs, outputs);
        }
        virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs,
                                    const std::vector<Tensor*>& outputs) override {
            for (auto tensor : inputs) {
                if (needsCheck(tensor) && hasNanOrInf(tensor)) {
                    return INVALID_VALUE;
                }
            }
            auto code = mExecution->onExecute(inputs, outputs);
            if (NO_ERROR != code) {
                return code;
            }
            for (auto tensor : outputs) {
                if (needsCheck(tensor) && hasNanOrInf(tensor)) {
                    return INVALID_VALUE;
                }
            }
            return NO_ERROR;
        }
    private:
        // Only float tensors with their own storage (not region-described virtual ones) are scanned.
        static bool needsCheck(const Tensor* tensor) {
            return halide_type_float == tensor->getType().code &&
                   TensorUtils::getDescribe(tensor)->memoryType != Tensor::InsideDescribe::MEMORY_VIRTUAL;
        }
        // True if any element is NaN or +/-Inf.
        static bool hasNanOrInf(const Tensor* tensor) {
            auto size = tensor->elementSize();
            auto ptr = tensor->host<float>();
            for (int i = 0; i < size; ++i) {
                if (!std::isfinite(ptr[i])) {
                    return true;
                }
            }
            return false;
        }
        std::unique_ptr<Execution> mExecution;
    };
    return new CheckNANExecution(execution);
}
2019-04-17 10:49:11 +08:00
bool CPUBackend::onClearBuffer() {
2020-12-15 14:12:35 +08:00
mDynamicAllocator->release(true);
2021-04-08 15:34:23 +08:00
mCachedCastTensor.clear();
2019-04-17 10:49:11 +08:00
return true;
}
2020-11-05 16:41:56 +08:00
2020-02-26 09:57:17 +08:00
std::pair<int, int> CPUBackend::multiThreadDivide(int size) const {
2020-11-05 16:41:56 +08:00
int sizeDivide = size / threadNumber();
2021-04-08 15:34:23 +08:00
sizeDivide = UP_DIV(sizeDivide, mCoreFunctions->pack) * mCoreFunctions->pack;
2020-02-26 09:57:17 +08:00
int scheduleNumber = 1;
if (sizeDivide > 0) {
scheduleNumber = UP_DIV(size, sizeDivide);
}
return std::make_pair(sizeDivide, scheduleNumber);
}
2019-04-17 10:49:11 +08:00
void CPUBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const {
auto& srcBuffer = srcTensor->buffer();
auto& dstBuffer = dstTensor->buffer();
MNN_ASSERT(srcBuffer.dimensions == dstBuffer.dimensions);
if (srcTensor->getDimensionType() == dstTensor->getDimensionType()) {
for (int i = 0; i < srcBuffer.dimensions; ++i) {
MNN_ASSERT(srcBuffer.dim[i].extent <= dstBuffer.dim[i].extent);
}
}
if (nullptr == srcBuffer.host || nullptr == dstBuffer.host) {
return;
}
2021-04-08 15:34:23 +08:00
if (srcBuffer.type != dstBuffer.type) {
ErrorCode code = NO_ERROR;
if (TensorUtils::getDescribe(srcTensor)->dimensionFormat != TensorUtils::getDescribe(dstTensor)->dimensionFormat) {
std::unique_ptr<Tensor> wrapTensor;
auto dimType = Tensor::CAFFE;
switch (TensorUtils::getDescribe(srcTensor)->dimensionFormat) {
case MNN_DATA_FORMAT_NCHW:
break;
case MNN_DATA_FORMAT_NC4HW4:
dimType = Tensor::CAFFE_C4;
break;
case MNN_DATA_FORMAT_NHWC:
dimType = Tensor::TENSORFLOW;
break;
default:
break;
}
wrapTensor.reset(Tensor::create(srcTensor->shape(), dstTensor->getType(), nullptr, dimType));
code = CPUCastCreator::cast(srcTensor, wrapTensor.get());
CPUTensorConverter::convert(wrapTensor.get(), dstTensor);
} else {
code = CPUCastCreator::cast(srcTensor, dstTensor);
}
2021-04-08 15:34:23 +08:00
if (NO_ERROR != code) {
MNN_ERROR("Error in CPUBackend::onCopyBuffer:cast\n");
}
return;
2021-04-08 15:34:23 +08:00
}
2019-12-27 22:16:57 +08:00
auto code = CPUTensorConverter::convert(srcTensor, dstTensor);
if (NO_ERROR != code) {
2021-04-08 15:34:23 +08:00
MNN_ERROR("Error in CPUBackend::onCopyBuffer:convert\n");
2019-12-27 22:16:57 +08:00
}
2019-04-17 10:49:11 +08:00
}
2020-11-05 16:41:56 +08:00
// Factory installed for MNN_FORWARD_CPU by registerCPURuntimeCreator(); each call builds a new CPURuntime.
class CPURuntimeCreator : public RuntimeCreator {
public:
    virtual Runtime* onCreate(const Backend::Info& info) const override {
        auto runtime = new CPURuntime(info);
        return runtime;
    }
};
2020-11-05 16:41:56 +08:00
2021-04-08 15:34:23 +08:00
#ifdef MNN_SUPPORT_BF16
extern void registerBF16Backend();
#endif
2020-11-05 16:41:56 +08:00
void registerCPURuntimeCreator() {
CPUBackend::initCreatorMap();
registerCPUOps();
2021-04-08 15:34:23 +08:00
#ifdef MNN_SUPPORT_BF16
registerBF16Backend();
#endif
// TODO: Merge _initCoreFunction MNNFunctionInit and cpuinfo_arm_init
MNNCoreFunctionInit();
MNNCoreInt8FunctionInit();
2020-11-05 16:41:56 +08:00
MNNInsertExtraRuntimeCreator(MNN_FORWARD_CPU, new CPURuntimeCreator);
};
2019-04-17 10:49:11 +08:00
} // namespace MNN