mirror of https://github.com/alibaba/MNN.git
[MNN:Sync] Sync internal Gitlab
This commit is contained in:
parent
7af23d29f4
commit
d91fc63976
|
@ -147,7 +147,7 @@ CTestTestfile.cmake
|
|||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.py[od]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
|
|
|
@ -96,8 +96,8 @@ IF(WIN32)
|
|||
SET(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /Zi")
|
||||
SET(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Zi")
|
||||
|
||||
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4267 /wd4018 /wd4251 /wd4996 /wd4244 /wd4146 /wd4129 /wd4305 /wd4275")
|
||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4018 /wd4251 /wd4996 /wd4244 /wd4146 /wd4129 /wd4305 /wd4275")
|
||||
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4267 /wd4018 /wd4251 /wd4996 /wd4244 /wd4146 /wd4129 /wd4305 /wd4275 /wd4819")
|
||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4018 /wd4251 /wd4996 /wd4244 /wd4146 /wd4129 /wd4305 /wd4275 /wd4819")
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
|
@ -124,9 +124,6 @@ endif()
|
|||
if(MNN_SUPPORT_TFLITE_QUAN)
|
||||
add_definitions(-DMNN_SUPPORT_TFLITE_QUAN)
|
||||
endif()
|
||||
if(MNN_BUILD_MINI)
|
||||
add_definitions(-DMNN_BUILD_MINI)
|
||||
endif()
|
||||
|
||||
# debug options
|
||||
if(MNN_DEBUG_MEMORY)
|
||||
|
@ -156,6 +153,12 @@ if (MNN_USE_THREAD_POOL)
|
|||
add_definitions(-DMNN_USE_THREAD_POOL)
|
||||
endif()
|
||||
|
||||
# When build Android based on arm32 by MTL, force turn off MNN_ARM82
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "^Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND)
|
||||
message(STATUS "force turn off MNN_ARM82 when build for Android based on arm32 by MTL")
|
||||
SET(MNN_ARM82 OFF CACHE BOOL "Enable ARM82" FORCE)
|
||||
endif()
|
||||
|
||||
# target options
|
||||
option(MNN_BUILD_BENCHMARK "Build benchmark or not" OFF)
|
||||
option(MNN_BUILD_TEST "Build tests or not" OFF)
|
||||
|
@ -181,6 +184,7 @@ message(STATUS "\toneDNN: ${MNN_ONEDNN}")
|
|||
message(STATUS "\tTensorRT: ${MNN_TENSORRT}")
|
||||
message(STATUS "\tCUDA: ${MNN_CUDA}")
|
||||
message(STATUS "\tOpenMP: ${MNN_OPENMP}")
|
||||
message(STATUS "\tBF16: ${MNN_SUPPORT_BF16}")
|
||||
message(STATUS "\tThreadPool: ${MNN_USE_THREAD_POOL}")
|
||||
message(STATUS "\tHidden: ${MNN_HIDDEN}")
|
||||
message(STATUS "\tBuild Path: ${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
@ -306,6 +310,9 @@ FILE(GLOB MNN_Core_SRC ${CMAKE_CURRENT_LIST_DIR}/source/core/*)
|
|||
add_library(MNNCore OBJECT ${MNN_Core_SRC})
|
||||
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNCore>)
|
||||
list(APPEND MNN_TARGETS MNNCore)
|
||||
if(MNN_BUILD_MINI)
|
||||
target_compile_options(MNNCore PRIVATE -DMNN_BUILD_MINI)
|
||||
endif()
|
||||
|
||||
# CV
|
||||
FILE(GLOB MNN_CV_SRC ${CMAKE_CURRENT_LIST_DIR}/source/cv/*)
|
||||
|
@ -340,23 +347,8 @@ add_library(MNNUtils OBJECT ${MNN_Utils_SRC})
|
|||
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNUtils>)
|
||||
list(APPEND MNN_TARGETS MNNUtils)
|
||||
|
||||
# CPU
|
||||
FILE(GLOB MNN_CPU_SRC ${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/* ${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/compute/*)
|
||||
add_library(MNNCPU OBJECT ${MNN_CPU_SRC})
|
||||
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNCPU>)
|
||||
list(APPEND MNN_TARGETS MNNCPU)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/CMakeLists.txt)
|
||||
|
||||
# X86_64 AVX/SSE
|
||||
if (MNN_USE_SSE)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/x86_x64/CMakeLists.txt)
|
||||
endif()
|
||||
|
||||
# AArch32/64 Assemblies
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/arm/CMakeLists.txt)
|
||||
|
||||
IF(NOT DEFINED IOS_ARCH)
|
||||
set(IOS_ARCH "")
|
||||
ENDIF()
|
||||
|
||||
SET(MNN_PUB_HDRS "")
|
||||
SET(MNN_EXPR_PUB_HDRS "")
|
||||
|
@ -513,16 +505,6 @@ IF(MNN_CUDA)
|
|||
list(APPEND MNN_EXTRA_DEPENDS ${MNN_CUDA_LIBS})
|
||||
ENDIF()
|
||||
|
||||
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR IOS_ARCH STREQUAL "arm64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
|
||||
# ARM82 Assemblies
|
||||
IF(MNN_ARM82)
|
||||
add_definitions(-DENABLE_ARMV82)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/source/backend/arm82/)
|
||||
list(APPEND MNN_TARGETS MNN_Arm82)
|
||||
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_Arm82>)
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
# Express
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/express/)
|
||||
IF(MNN_SEP_BUILD)
|
||||
|
|
|
@ -81,7 +81,7 @@ function bench_android() {
|
|||
#benchmark OpenGL
|
||||
#adb shell "LD_LIBRARY_PATH=$ANDROID_DIR $ANDROID_DIR/benchmark.out $ANDROID_DIR/benchmark_models $RUN_LOOP 5 6 2>$ANDROID_DIR/benchmark.err >> $ANDROID_DIR/benchmark.txt"
|
||||
#benchmark OpenCL
|
||||
#adb shell "LD_LIBRARY_PATH=$ANDROID_DIR $ANDROID_DIR/benchmark.out $ANDROID_DIR/benchmark_models $RUN_LOOP 5 3 2>$ANDROID_DIR/benchmark.err >> $ANDROID_DIR/benchmark.txt"
|
||||
#adb shell "LD_LIBRARY_PATH=$ANDROID_DIR $ANDROID_DIR/benchmark.out $ANDROID_DIR/benchmark_models 100 20 3 2>$ANDROID_DIR/benchmark.err >> $ANDROID_DIR/benchmark.txt"
|
||||
adb pull $ANDROID_DIR/benchmark.txt ../
|
||||
}
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ std::vector<Model> findModelFiles(const char* dir) {
|
|||
#if defined(_MSC_VER)
|
||||
WIN32_FIND_DATA ffd;
|
||||
HANDLE hFind = INVALID_HANDLE_VALUE;
|
||||
std::string mnn_model_pattern = std::string(dir) + "\\*.mnn";
|
||||
std::string mnn_model_pattern = std::string(dir) + "\\*.mnn";
|
||||
hFind = FindFirstFile(mnn_model_pattern.c_str(), &ffd);
|
||||
if (INVALID_HANDLE_VALUE == hFind) {
|
||||
std::cout << "open " << dir << " failed: " << strerror(errno) << std::endl;
|
||||
|
@ -178,7 +178,7 @@ void displayStats(const std::string& name, const std::vector<float>& costs) {
|
|||
//printf("[ - ] cost:%f ms\n", v);
|
||||
}
|
||||
avg = costs.size() > 0 ? sum / costs.size() : 0;
|
||||
printf("[ - ] %-24s max = %8.3fms min = %8.3fms avg = %8.3fms\n", name.c_str(), max, avg == 0 ? 0 : min, avg);
|
||||
printf("[ - ] %-24s max = %8.3f ms min = %8.3f ms avg = %8.3f ms\n", name.c_str(), max, avg == 0 ? 0 : min, avg);
|
||||
}
|
||||
static inline std::string forwardType(MNNForwardType type) {
|
||||
switch (type) {
|
||||
|
@ -318,7 +318,7 @@ void set_cpu_affinity()
|
|||
int cpu_id = 0;
|
||||
cpu_set_t mask;
|
||||
CPU_ZERO(&mask);
|
||||
|
||||
|
||||
auto numberOfCPUs = getNumberOfCPU();
|
||||
static std::vector<int> sortedCPUIDs;
|
||||
static int littleClusterOffset = 0;
|
||||
|
@ -379,10 +379,10 @@ int main(int argc, const char* argv[]) {
|
|||
std::vector<Model> models = findModelFiles(argv[1]);
|
||||
|
||||
std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl;
|
||||
|
||||
|
||||
/* not called yet */
|
||||
// set_cpu_affinity();
|
||||
|
||||
|
||||
for (auto& m : models) {
|
||||
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision);
|
||||
displayStats(m.name, costs);
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -6,7 +6,9 @@ option(MNN_CODEGEN_JIT "Build jit for codegen." OFF)
|
|||
|
||||
file(GLOB CODEGEN_HEADER "${CMAKE_CURRENT_LIST_DIR}/*.*")
|
||||
file(GLOB CPU_SRCS "${CMAKE_CURRENT_LIST_DIR}/cpu/*.*")
|
||||
file(GLOB JIT_SRCS "${CMAKE_CURRENT_LIST_DIR}/jit/*.*")
|
||||
list(APPEND MNN_CODEGEN_SRCS ${CODEGEN_HEADER})
|
||||
list(APPEND MNN_CODEGEN_SRCS ${JIT_SRCS})
|
||||
|
||||
if(MNN_CODEGEN_OPENCL)
|
||||
add_definitions(-DMNN_CODEGEN_OPENCL)
|
||||
|
@ -34,7 +36,7 @@ if(MNN_CODEGEN_LLVM)
|
|||
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
llvm_map_components_to_libnames(llvm_libs core bitwriter)
|
||||
llvm_map_components_to_libnames(llvm_libs core bitwriter OrcJIT Support nativecodegen native CodeGen)
|
||||
list(APPEND MNN_EXTRA_DEPENDS ${llvm_libs})
|
||||
endif()
|
||||
|
||||
|
|
|
@ -9,9 +9,11 @@
|
|||
#include "OpFuse.hpp"
|
||||
#include "geometry/GeometryComputerUtils.hpp"
|
||||
#include "PluginModule.hpp"
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include "cpu/CPUAst.hpp"
|
||||
#include "jit/LLVMJit.hpp"
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
@ -73,6 +75,7 @@ bool isLegal(const Command* cmd) {
|
|||
if (elemWise) {
|
||||
return true;
|
||||
}
|
||||
#define fuse_raster
|
||||
#ifdef fuse_raster
|
||||
if (type == OpType_Raster) {
|
||||
auto outputFormat = TensorUtils::getDescribe(cmd->outputs[0])->dimensionFormat;
|
||||
|
@ -134,6 +137,136 @@ std::vector<Node*> fuseNode(Node* root, std::vector<Node*>& edges) {
|
|||
}
|
||||
return fuseSet;
|
||||
}
|
||||
|
||||
void codegen(CommandBuffer& cmd, std::vector<std::vector<Node*>>& fuseSets) {
|
||||
// generate Kernel
|
||||
CPUPluginModule plugin("codegen_demo");
|
||||
for (auto compSet : fuseSets) {
|
||||
// printf("set size: %lu \n", compSet.size());
|
||||
InOutTensors tensors = plugin.addFunction(compSet);
|
||||
auto inputs = tensors.first;
|
||||
auto outputs = tensors.second;
|
||||
// build Plugin Op
|
||||
Command cmdPlugin;
|
||||
{
|
||||
std::unique_ptr<OpT> pluginOp(new OpT);
|
||||
pluginOp->type = OpType_Plugin;
|
||||
pluginOp->name = "PluginWrapper";
|
||||
PluginT* plugin_param = new PluginT;
|
||||
plugin_param->type = "PluginWrapper";
|
||||
plugin_param->attr.resize(1);
|
||||
plugin_param->attr[0].reset(new AttributeT);
|
||||
plugin_param->attr[0]->key = "kernel";
|
||||
plugin_param->attr[0]->i = plugin.getFunctionNum()-1;
|
||||
pluginOp->main.type = OpParameter_Plugin;
|
||||
pluginOp->main.value = plugin_param;
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = Op::Pack(builder, pluginOp.get());
|
||||
builder.Finish(lastOffset);
|
||||
cmdPlugin = GeometryComputerUtils::makeCommand(builder, inputs, outputs);
|
||||
}
|
||||
for (int i = 0; i < compSet.size(); i++) {
|
||||
auto cmd = const_cast<Command*>(compSet[i]->cmd);
|
||||
if (i == compSet.size()-1) {
|
||||
cmd->op = cmdPlugin.op;
|
||||
cmd->inputs = cmdPlugin.inputs;
|
||||
cmd->outputs = cmdPlugin.outputs;
|
||||
cmd->buffer = cmdPlugin.buffer;
|
||||
} else {
|
||||
cmd->op = nullptr;
|
||||
cmd->buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
// printf("total: %d\n", idx);
|
||||
plugin.codegen();
|
||||
// printf("cmd num: %lu \n", cmd.command.size());
|
||||
for (auto iter = cmd.command.begin(); iter != cmd.command.end();) {
|
||||
if (iter->op == nullptr) {
|
||||
iter = cmd.command.erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
#if !defined(_MSC_VER)
|
||||
// printf("cmd num: %lu \n", cmd.command.size());
|
||||
dlopen("./libplugin_fuse.so", RTLD_NOW | RTLD_LOCAL);
|
||||
#endif
|
||||
}
|
||||
|
||||
void jit(CommandBuffer& cmd, std::vector<std::vector<Node*>>& fuseSets) {
|
||||
LLVMJIT* theJit = LLVMJIT::createLLVMJIT();
|
||||
CPUPluginModule plugin("jit_demo");
|
||||
std::string kernelStr;
|
||||
for (auto compSet : fuseSets) {
|
||||
/*
|
||||
// printf("set size: %lu \n", compSet.size());
|
||||
if (true) {
|
||||
for (auto com : compSet) {
|
||||
// json :
|
||||
// { fusedOps: [ { idx:int, srcOps: [name: string], inputs:[name:string], outputs:[name:string] } ], dynlib:string, jitObj:string, module:string }
|
||||
dumpCmd(com->cmd);
|
||||
}
|
||||
}
|
||||
*/
|
||||
kernelStr += "[";
|
||||
for (auto com : compSet) {
|
||||
kernelStr += com->cmd->op->name()->str();
|
||||
}
|
||||
kernelStr += "]";
|
||||
InOutTensors tensors = plugin.addFunction(compSet);
|
||||
auto inputs = tensors.first;
|
||||
auto outputs = tensors.second;
|
||||
// build Plugin Op
|
||||
Command cmdPlugin;
|
||||
{
|
||||
std::unique_ptr<OpT> pluginOp(new OpT);
|
||||
pluginOp->type = OpType_Plugin;
|
||||
pluginOp->name = "JitPluginWrapper";
|
||||
PluginT* plugin_param = new PluginT;
|
||||
plugin_param->type = "JitPluginWrapper";
|
||||
plugin_param->attr.resize(1);
|
||||
plugin_param->attr[0].reset(new AttributeT);
|
||||
plugin_param->attr[0]->key = "kernel";
|
||||
plugin_param->attr[0]->i = plugin.getFunctionNum() - 1;
|
||||
pluginOp->main.type = OpParameter_Plugin;
|
||||
pluginOp->main.value = plugin_param;
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = Op::Pack(builder, pluginOp.get());
|
||||
builder.Finish(lastOffset);
|
||||
cmdPlugin = GeometryComputerUtils::makeCommand(builder, inputs, outputs);
|
||||
}
|
||||
for (int i = 0; i < compSet.size(); i++) {
|
||||
auto cmd = const_cast<Command*>(compSet[i]->cmd);
|
||||
if (i == compSet.size()-1) {
|
||||
cmd->op = cmdPlugin.op;
|
||||
cmd->inputs = cmdPlugin.inputs;
|
||||
cmd->outputs = cmdPlugin.outputs;
|
||||
cmd->buffer = cmdPlugin.buffer;
|
||||
} else {
|
||||
cmd->op = nullptr;
|
||||
cmd->buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto iter = cmd.command.begin(); iter != cmd.command.end();) {
|
||||
if (iter->op == nullptr) {
|
||||
iter = cmd.command.erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
size_t id = std::hash<std::string>()(kernelStr);
|
||||
std::unique_ptr<LLVMTarget> target(new LLVMTarget("jit-kenerl-" + std::to_string(id)));
|
||||
target->getModule()->setDataLayout(theJit->getDataLayout());
|
||||
plugin.codegen(target.get());
|
||||
// add module to JIT and compile
|
||||
auto m = target->getThreadSafeModule();
|
||||
auto resourceTracker = theJit->getMainJITDylib().createResourceTracker();
|
||||
theJit->addModule(std::move(m), resourceTracker);
|
||||
theJit->compileAllFunction(plugin.getFunctionNum());
|
||||
}
|
||||
|
||||
bool opFuse(CommandBuffer& cmd) {
|
||||
std::unordered_map<const Tensor*, Node*> outputTensor;
|
||||
// build graph
|
||||
|
@ -208,59 +341,7 @@ bool opFuse(CommandBuffer& cmd) {
|
|||
postDominateNodeQueue.push(child);
|
||||
}
|
||||
}
|
||||
// generate Kernel
|
||||
CPUPluginModule plugin("fuse_demo");
|
||||
for (auto compSet : fuseSets) {
|
||||
// printf("set size: %lu \n", compSet.size());
|
||||
InOutTensors tensors = plugin.addFunction(compSet);
|
||||
auto inputs = tensors.first;
|
||||
auto outputs = tensors.second;
|
||||
// build Plugin Op
|
||||
Command cmdPlugin;
|
||||
{
|
||||
std::unique_ptr<OpT> pluginOp(new OpT);
|
||||
pluginOp->type = OpType_Plugin;
|
||||
pluginOp->name = "PluginWrapper";
|
||||
PluginT* plugin_param = new PluginT;
|
||||
plugin_param->type = "PluginWrapper";
|
||||
plugin_param->attr.resize(1);
|
||||
plugin_param->attr[0].reset(new AttributeT);
|
||||
plugin_param->attr[0]->key = "kernel";
|
||||
plugin_param->attr[0]->i = plugin.getFunctionNum()-1;
|
||||
pluginOp->main.type = OpParameter_Plugin;
|
||||
pluginOp->main.value = plugin_param;
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = Op::Pack(builder, pluginOp.get());
|
||||
builder.Finish(lastOffset);
|
||||
cmdPlugin = GeometryComputerUtils::makeCommand(builder, inputs, outputs);
|
||||
}
|
||||
for (int i = 0; i < compSet.size(); i++) {
|
||||
auto cmd = const_cast<Command*>(compSet[i]->cmd);
|
||||
if (i == compSet.size()-1) {
|
||||
cmd->op = cmdPlugin.op;
|
||||
cmd->inputs = cmdPlugin.inputs;
|
||||
cmd->outputs = cmdPlugin.outputs;
|
||||
cmd->buffer = cmdPlugin.buffer;
|
||||
} else {
|
||||
cmd->op = nullptr;
|
||||
cmd->buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
// printf("total: %d\n", idx);
|
||||
plugin.codegen();
|
||||
// printf("cmd num: %lu \n", cmd.command.size());
|
||||
for (auto iter = cmd.command.begin(); iter != cmd.command.end();) {
|
||||
if (iter->op == nullptr) {
|
||||
iter = cmd.command.erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
#if !defined(_MSC_VER)
|
||||
// printf("cmd num: %lu \n", cmd.command.size());
|
||||
dlopen("./libplugin_fuse.so", RTLD_NOW | RTLD_LOCAL);
|
||||
#endif
|
||||
jit(cmd, fuseSets);
|
||||
return true;
|
||||
}
|
||||
} // namespace MNN
|
||||
|
|
|
@ -38,6 +38,7 @@ public:
|
|||
virtual void codegen() = 0;
|
||||
};
|
||||
|
||||
class LLVMTarget;
|
||||
#ifdef MNN_CODEGEN_CPU
|
||||
class CPUPluginModule : PluginModule{
|
||||
public:
|
||||
|
@ -49,6 +50,7 @@ public:
|
|||
InOutTensors addFunction(std::vector<Node*> nodes) override;
|
||||
const int getFunctionNum() override { return functions.size(); }
|
||||
void codegen() override;
|
||||
void codegen(LLVMTarget* target);
|
||||
private:
|
||||
class CPUPluginFunction;
|
||||
std::vector<std::unique_ptr<CPUPluginFunction>> functions;
|
||||
|
|
|
@ -21,47 +21,45 @@
|
|||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
|
||||
using namespace llvm;
|
||||
using namespace llvm::orc;
|
||||
#endif
|
||||
|
||||
class Target {
|
||||
public:
|
||||
Target() {}
|
||||
virtual ~Target() {}
|
||||
private:
|
||||
std::string name;
|
||||
};
|
||||
|
||||
#ifdef MNN_CODEGEN_LLVM
|
||||
class LLVMTarget : public Target {
|
||||
class LLVMTarget {
|
||||
public:
|
||||
LLVMTarget(std::string& name) {
|
||||
llvmBuilder = std::make_unique<IRBuilder<>>(llvmContext);
|
||||
llvmModule = std::make_unique<Module>(name, llvmContext);
|
||||
llvmModule->setTargetTriple("x86_64-apple-macosx10.15.0");
|
||||
LLVMTarget(std::string name) {
|
||||
llvmContext.reset(new LLVMContext);
|
||||
llvmBuilder = std::make_unique<IRBuilder<>>(*llvmContext.get());
|
||||
llvmModule = std::make_unique<Module>(name, *llvmContext.get());
|
||||
llvmModule->setTargetTriple("x86_64-apple-macosx11.0.0");
|
||||
}
|
||||
~LLVMTarget() override {}
|
||||
~LLVMTarget() {}
|
||||
Module* getModule() {
|
||||
return llvmModule.get();
|
||||
}
|
||||
LLVMContext& getContext() {
|
||||
return llvmContext;
|
||||
return *llvmContext.get();
|
||||
}
|
||||
IRBuilder<>* getBuilder() {
|
||||
return llvmBuilder.get();
|
||||
}
|
||||
ThreadSafeModule getThreadSafeModule() {
|
||||
return ThreadSafeModule(std::move(llvmModule), std::move(llvmContext));
|
||||
}
|
||||
private:
|
||||
LLVMContext llvmContext;
|
||||
std::unique_ptr<LLVMContext> llvmContext;
|
||||
std::unique_ptr<IRBuilder<>> llvmBuilder;
|
||||
std::unique_ptr<Module> llvmModule;
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef MNN_CODEGEN_C
|
||||
class SourceTarget : public Target {
|
||||
class SourceTarget {
|
||||
public:
|
||||
SourceTarget() {}
|
||||
~SourceTarget() override {}
|
||||
~SourceTarget() {}
|
||||
void addIndent() { indent++; }
|
||||
void subIndent() { indent--; }
|
||||
std::string getIndent() {
|
||||
|
@ -74,7 +72,7 @@ private:
|
|||
class CTarget : public SourceTarget {
|
||||
public:
|
||||
CTarget(std::string& name) {}
|
||||
~CTarget() override {}
|
||||
~CTarget() {}
|
||||
};
|
||||
#endif
|
||||
|
||||
|
|
|
@ -233,6 +233,12 @@ private:
|
|||
std::unique_ptr<FunctionAST> function;
|
||||
};
|
||||
|
||||
void CPUPluginModule::codegen(LLVMTarget* target) {
|
||||
for (int i = 0; i < getFunctionNum(); i++) {
|
||||
functions[i]->codegen(target);
|
||||
}
|
||||
}
|
||||
|
||||
void CPUPluginModule::codegen() {
|
||||
std::ofstream headerFile("./kernel.h");
|
||||
std::ofstream sourceFile("./kernel.c");
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
//
|
||||
// JitPluginWrapper.cpp
|
||||
// Codegen
|
||||
//
|
||||
// Created by MNN on 2021/01/29.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#include "jit/LLVMJit.hpp"
|
||||
#include "MNN/plugin/PluginKernel.hpp"
|
||||
#include "cpu/CPUAst.hpp"
|
||||
#include <vector>
|
||||
|
||||
MNN_PUBLIC int _intPluginWrapper = 10; // Just for linking successfully.
|
||||
|
||||
using namespace llvm;
|
||||
using namespace llvm::orc;
|
||||
|
||||
namespace MNN {
|
||||
namespace plugin {
|
||||
|
||||
namespace backend {
|
||||
class JitPluginWrapper : public CPUComputeKernel {
|
||||
public:
|
||||
bool init(CPUKernelContext*) override { return true; }
|
||||
bool compute(CPUKernelContext* ctx) override;
|
||||
};
|
||||
|
||||
bool JitPluginWrapper::compute(CPUKernelContext* ctx) {
|
||||
int kernelIdx = 0;
|
||||
if (ctx->hasAttr("kernel")) {
|
||||
kernelIdx = ctx->getAttr("kernel")->i();
|
||||
}
|
||||
|
||||
LLVMJIT* jit = LLVMJIT::createLLVMJIT();
|
||||
MNN_ASSERT(jit != nullptr);
|
||||
|
||||
int I = ctx->inputs().size();
|
||||
float** inputs = new float*[I];
|
||||
for (int i = 0; i < I; i++) {
|
||||
inputs[i] = reinterpret_cast<float*>(ctx->input(i)->buffer().host);
|
||||
}
|
||||
int O = ctx->outputs().size();
|
||||
float** outputs = new float*[O];
|
||||
for (int i = 0; i < O; i++) {
|
||||
outputs[i] = reinterpret_cast<float*>(ctx->output(i)->buffer().host);
|
||||
}
|
||||
void (*kernel)(float**, float**) = (void (*)(float**, float**))jit->getFuncByIdx(kernelIdx);
|
||||
kernel(inputs, outputs);
|
||||
return true;
|
||||
}
|
||||
} // namespace backend
|
||||
|
||||
REGISTER_PLUGIN_COMPUTE_KERNEL(JitPluginWrapper, backend::JitPluginWrapper);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace MNN
|
|
@ -0,0 +1,187 @@
|
|||
//
|
||||
// LLVMJit.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2021/2/2.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include "jit/LLVMJit.hpp"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
|
||||
|
||||
#include "llvm/ExecutionEngine/ObjectCache.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <memory>
|
||||
class MCJITObjectCache : public ObjectCache {
|
||||
public:
|
||||
MCJITObjectCache() {
|
||||
sys::fs::current_path(CacheDir);
|
||||
sys::path::append(CacheDir, "mnn_object_cache");
|
||||
}
|
||||
|
||||
virtual ~MCJITObjectCache() {}
|
||||
|
||||
bool isCached(std::string moduleId) {
|
||||
SmallString<128> IRCacheFile = CacheDir;
|
||||
sys::path::append(IRCacheFile, moduleId);
|
||||
return sys::fs::exists(IRCacheFile.str());
|
||||
}
|
||||
|
||||
virtual void notifyObjectCompiled(const Module *M, MemoryBufferRef Obj) {
|
||||
const std::string ModuleID = M->getModuleIdentifier();
|
||||
|
||||
if (0 == ModuleID.compare(0, 4, "jit-")) {
|
||||
std::string IRFileName = ModuleID;
|
||||
SmallString<128>IRCacheFile = CacheDir;
|
||||
sys::path::append(IRCacheFile, IRFileName);
|
||||
if (!sys::fs::exists(CacheDir.str()) && sys::fs::create_directory(CacheDir.str())) {
|
||||
fprintf(stderr, "Unable to create cache directory\n");
|
||||
return;
|
||||
}
|
||||
std::error_code ec;
|
||||
raw_fd_ostream IRObjectFile(IRCacheFile.c_str(), ec, sys::fs::F_None);
|
||||
IRObjectFile << Obj.getBuffer();
|
||||
}
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<MemoryBuffer> getObject(const Module* M) {
|
||||
if (!isCached(M->getModuleIdentifier())) {
|
||||
return nullptr;
|
||||
}
|
||||
SmallString<128> IRCacheFile = CacheDir;
|
||||
sys::path::append(IRCacheFile, M->getModuleIdentifier());
|
||||
ErrorOr<std::unique_ptr<MemoryBuffer>> IRObjectBuffer = MemoryBuffer::getFile(IRCacheFile.c_str(), -1, false);
|
||||
if (!IRObjectBuffer) {
|
||||
return nullptr;
|
||||
}
|
||||
return MemoryBuffer::getMemBufferCopy(IRObjectBuffer.get()->getBuffer());
|
||||
}
|
||||
|
||||
private:
|
||||
SmallString<128> CacheDir;
|
||||
};
|
||||
|
||||
static MCJITObjectCache cacheObj;
|
||||
LLVMJIT* LLVMJIT::llvmJit = nullptr;
|
||||
|
||||
LLVMJIT::LLVMJIT(std::unique_ptr<TargetProcessControl> tpc, std::unique_ptr<ExecutionSession> es, JITTargetMachineBuilder jtmb, DataLayout dl)
|
||||
: processControl(std::move(tpc)), executionSession(std::move(es)), dataLayout(std::move(dl)),
|
||||
mangle(*this->executionSession, this->dataLayout),
|
||||
objectLayer(*this->executionSession, []() { return std::make_unique<SectionMemoryManager>(); }),
|
||||
compileLayer(*this->executionSession, objectLayer, std::make_unique<ConcurrentIRCompiler>(std::move(jtmb))),
|
||||
optimizeLayer(*this->executionSession, compileLayer, optimizeModule),
|
||||
mainJD(this->executionSession->createBareJITDylib("<main>")) {
|
||||
mainJD.addGenerator(cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(dl.getGlobalPrefix())));
|
||||
}
|
||||
|
||||
LLVMJIT::~LLVMJIT() {
|
||||
if (auto Err = executionSession->endSession()) {
|
||||
executionSession->reportError(std::move(Err));
|
||||
}
|
||||
}
|
||||
|
||||
void LLVMJIT::addModule(ThreadSafeModule tsm, ResourceTrackerSP rt) {
|
||||
if (!rt) {
|
||||
rt = mainJD.getDefaultResourceTracker();
|
||||
}
|
||||
ExitOnErr(optimizeLayer.add(rt, std::move(tsm)));
|
||||
}
|
||||
|
||||
Expected<JITEvaluatedSymbol> LLVMJIT::lookup(StringRef Name) {
|
||||
return executionSession->lookup({&mainJD}, mangle(Name.str()));
|
||||
}
|
||||
|
||||
void LLVMJIT::compileAllFunction(int num) {
|
||||
auto comp = static_cast<ConcurrentIRCompiler*>(&compileLayer.getCompiler());
|
||||
comp->setObjectCache(&cacheObj);
|
||||
functions.resize(num);
|
||||
for (int i = 0; i < num; i++) {
|
||||
functions[i] = getFuncByName("kernel_" + std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t LLVMJIT::getFuncByName(std::string name) {
|
||||
return ExitOnErr(lookup(name)).getAddress();
|
||||
}
|
||||
|
||||
uint64_t LLVMJIT::getFuncByIdx(int idx) {
|
||||
if (functions.size() <= idx) {
|
||||
return 0;
|
||||
}
|
||||
return functions[idx];
|
||||
}
|
||||
|
||||
LLVMJIT* LLVMJIT::createLLVMJIT() {
|
||||
if (llvmJit != nullptr) {
|
||||
return llvmJit;
|
||||
}
|
||||
InitializeNativeTarget();
|
||||
InitializeNativeTargetAsmPrinter();
|
||||
InitializeNativeTargetAsmParser();
|
||||
auto tpc = SelfTargetProcessControl::Create();
|
||||
if (!tpc) {
|
||||
return nullptr;
|
||||
}
|
||||
auto es = std::make_unique<ExecutionSession>();
|
||||
JITTargetMachineBuilder jtmb((*tpc)->getTargetTriple());
|
||||
auto dl = jtmb.getDefaultDataLayoutForTarget();
|
||||
if (!dl) {
|
||||
return nullptr;
|
||||
}
|
||||
llvmJit = new LLVMJIT(std::move(*tpc), std::move(es), std::move(jtmb), std::move(*dl));
|
||||
return llvmJit;
|
||||
}
|
||||
|
||||
TargetMachine* LLVMJIT::GetTargetMachine(Triple TheTriple) {
|
||||
std::string Error;
|
||||
const Target *TheTarget = TargetRegistry::lookupTarget(codegen::getMArch(), TheTriple, Error);
|
||||
if (!TheTarget) {
|
||||
return nullptr;
|
||||
}
|
||||
return TheTarget->createTargetMachine(TheTriple.getTriple(), codegen::getCPUStr(), codegen::getFeaturesStr(), codegen::InitTargetOptionsFromCodeGenFlags(TheTriple),
|
||||
codegen::getExplicitRelocModel(), codegen::getExplicitCodeModel(), CodeGenOpt::Aggressive);
|
||||
}
|
||||
|
||||
Expected<ThreadSafeModule> LLVMJIT::optimizeModule(ThreadSafeModule tsm, const MaterializationResponsibility &mr) {
|
||||
static codegen::RegisterCodeGenFlags CFG;
|
||||
tsm.withModuleDo([](Module &m) {
|
||||
if (cacheObj.isCached(m.getModuleIdentifier())) {
|
||||
return;
|
||||
}
|
||||
auto modulePassManager = std::make_unique<legacy::PassManager>();
|
||||
auto funcPassManager = std::make_unique<legacy::FunctionPassManager>(&m);
|
||||
{
|
||||
Triple moduleTriple(m.getTargetTriple());
|
||||
TargetMachine *Machine = nullptr;
|
||||
if (moduleTriple.getArch()) {
|
||||
Machine = GetTargetMachine(moduleTriple);
|
||||
}
|
||||
modulePassManager->add(createTargetTransformInfoWrapperPass(Machine->getTargetIRAnalysis()));
|
||||
funcPassManager->add(createTargetTransformInfoWrapperPass(Machine->getTargetIRAnalysis()));
|
||||
PassManagerBuilder builder;
|
||||
builder.OptLevel = 3;
|
||||
builder.SizeLevel = 0;
|
||||
// builder.Inliner = createFunctionInliningPass(3, 0, false);
|
||||
builder.DisableUnrollLoops = false;
|
||||
builder.LoopVectorize = true;
|
||||
builder.SLPVectorize = true;
|
||||
builder.populateFunctionPassManager(*funcPassManager.get());
|
||||
builder.populateModulePassManager(*modulePassManager.get());
|
||||
funcPassManager->doInitialization();
|
||||
for (auto &function : m) {
|
||||
funcPassManager->run(function);
|
||||
}
|
||||
funcPassManager->doFinalization();
|
||||
modulePassManager->run(m);
|
||||
}
|
||||
});
|
||||
return std::move(tsm);
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
//
|
||||
// LLVMJit.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2021/2/2.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include "llvm/IR/DataLayout.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
|
||||
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
|
||||
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace llvm::orc;
|
||||
|
||||
class LLVMJIT {
|
||||
public:
|
||||
LLVMJIT(std::unique_ptr<TargetProcessControl> tpc, std::unique_ptr<ExecutionSession> es, JITTargetMachineBuilder jtmb, DataLayout dl);
|
||||
|
||||
~LLVMJIT();
|
||||
|
||||
static LLVMJIT* createLLVMJIT();
|
||||
|
||||
const DataLayout &getDataLayout() const { return dataLayout; }
|
||||
|
||||
JITDylib &getMainJITDylib() { return mainJD; }
|
||||
|
||||
void addModule(ThreadSafeModule tsm, ResourceTrackerSP rt = nullptr);
|
||||
|
||||
Expected<JITEvaluatedSymbol> lookup(StringRef Name);
|
||||
|
||||
void compileAllFunction(int num);
|
||||
|
||||
uint64_t getFuncByName(std::string name);
|
||||
|
||||
uint64_t getFuncByIdx(int idx);
|
||||
private:
|
||||
static TargetMachine* GetTargetMachine(Triple TheTriple);
|
||||
static Expected<ThreadSafeModule> optimizeModule(ThreadSafeModule tsm, const MaterializationResponsibility &mr);
|
||||
private:
|
||||
std::unique_ptr<TargetProcessControl> processControl;
|
||||
std::unique_ptr<ExecutionSession> executionSession;
|
||||
std::vector<uint64_t> functions;
|
||||
RTDyldObjectLinkingLayer objectLayer;
|
||||
IRCompileLayer compileLayer;
|
||||
IRTransformLayer optimizeLayer;
|
||||
DataLayout dataLayout;
|
||||
MangleAndInterner mangle;
|
||||
JITDylib &mainJD;
|
||||
ExitOnError ExitOnErr;
|
||||
Triple targetTriple;
|
||||
static LLVMJIT* llvmJit;
|
||||
};
|
|
@ -18,6 +18,7 @@
|
|||
#include <MNN/expr/Expr.hpp>
|
||||
#include <MNN/expr/ExprCreator.hpp>
|
||||
#include <MNN/AutoTime.hpp>
|
||||
#include <MNN/Interpreter.hpp>
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
|
@ -32,39 +33,28 @@ int main(int argc, const char* argv[]) {
|
|||
MNN_PRINT("Usage: ./segment.out model.mnn input.jpg output.jpg\n");
|
||||
return 0;
|
||||
}
|
||||
auto net = Variable::getInputAndOutput(Variable::loadMap(argv[1]));
|
||||
if (net.first.empty()) {
|
||||
std::shared_ptr<Interpreter> net;
|
||||
net.reset(Interpreter::createFromFile(argv[1]));
|
||||
if (net == nullptr) {
|
||||
MNN_ERROR("Invalid Model\n");
|
||||
return 0;
|
||||
}
|
||||
auto input = net.first.begin()->second;
|
||||
auto info = input->getInfo();
|
||||
if (nullptr == info) {
|
||||
MNN_ERROR("The model don't have init dim\n");
|
||||
return 0;
|
||||
ScheduleConfig config;
|
||||
auto session = net->createSession(config);
|
||||
auto input = net->getSessionInput(session, nullptr);
|
||||
auto shape = input->shape();
|
||||
if (shape[0] != 1) {
|
||||
shape[0] = 1;
|
||||
net->resizeTensor(input, shape);
|
||||
net->resizeSession(session);
|
||||
}
|
||||
auto shape = input->getInfo()->dim;
|
||||
shape[0] = 1;
|
||||
input->resize(shape);
|
||||
auto output = net.second.begin()->second;
|
||||
if (nullptr == output->getInfo()) {
|
||||
MNN_ERROR("Alloc memory or compute size error\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
{
|
||||
int size_w = 0;
|
||||
int size_h = 0;
|
||||
int bpp = 0;
|
||||
if (info->order == NHWC) {
|
||||
bpp = shape[3];
|
||||
size_h = shape[1];
|
||||
size_w = shape[2];
|
||||
} else {
|
||||
bpp = shape[1];
|
||||
size_h = shape[2];
|
||||
size_w = shape[3];
|
||||
}
|
||||
bpp = shape[1];
|
||||
size_h = shape[2];
|
||||
size_w = shape[3];
|
||||
if (bpp == 0)
|
||||
bpp = 1;
|
||||
if (size_h == 0)
|
||||
|
@ -97,47 +87,44 @@ int main(int argc, const char* argv[]) {
|
|||
|
||||
std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
|
||||
pretreat->setMatrix(trans);
|
||||
pretreat->convert((uint8_t*)inputImage, width, height, 0, input->writeMap<float>(), size_w, size_h, 4, 0, halide_type_of<float>());
|
||||
pretreat->convert((uint8_t*)inputImage, width, height, 0, input->host<float>(), size_w, size_h, 4, 0, halide_type_of<float>());
|
||||
stbi_image_free(inputImage);
|
||||
input->unMap();
|
||||
}
|
||||
// Run model
|
||||
net->runSession(session);
|
||||
|
||||
// Post treat by MNN-Express
|
||||
{
|
||||
//auto originOrder = output->getInfo()->order;
|
||||
output = _Convert(output, NHWC);
|
||||
//output = _Softmax(output, -1);
|
||||
auto outputInfo = output->getInfo();
|
||||
auto width = outputInfo->dim[2];
|
||||
auto height = outputInfo->dim[1];
|
||||
auto channel = outputInfo->dim[3];
|
||||
std::shared_ptr<Tensor> wrapTensor(ImageProcess::createImageTensor<uint8_t>(width, height, 4, nullptr));
|
||||
MNN_PRINT("Mask: w=%d, h=%d, c=%d\n", width, height, channel);
|
||||
auto outputHostPtr = output->readMap<float>();
|
||||
for (int y = 0; y < height; ++y) {
|
||||
auto rgbaY = wrapTensor->host<uint8_t>() + 4 * y * width;
|
||||
auto sourceY = outputHostPtr + y * width * channel;
|
||||
for (int x=0; x<width; ++x) {
|
||||
auto sourceX = sourceY + channel * x;
|
||||
int index = 0;
|
||||
float maxValue = sourceX[0];
|
||||
auto rgba = rgbaY + 4 * x;
|
||||
for (int c=1; c<channel; ++c) {
|
||||
if (sourceX[c] > maxValue) {
|
||||
index = c;
|
||||
maxValue = sourceX[c];
|
||||
}
|
||||
}
|
||||
rgba[0] = 255;
|
||||
rgba[2] = 0;
|
||||
rgba[1] = 0;
|
||||
rgba[3] = 255;
|
||||
if (15 == index) {
|
||||
rgba[2] = 255;
|
||||
rgba[3] = 0;
|
||||
}
|
||||
}
|
||||
/* Create VARP by tensor Begin*/
|
||||
auto outputTensor = net->getSessionOutput(session, nullptr);
|
||||
// First Create a Expr, then create Variable by the 0 index of expr
|
||||
auto output = Variable::create(Expr::create(outputTensor));
|
||||
if (nullptr == output->getInfo()) {
|
||||
MNN_ERROR("Alloc memory or compute size error\n");
|
||||
return 0;
|
||||
}
|
||||
output->unMap();
|
||||
stbi_write_png(argv[3], width, height, 4, wrapTensor->host<uint8_t>(), 4 * width);
|
||||
/* Create VARP by tensor End*/
|
||||
|
||||
// Turn dataFormat to NHWC for easy to run TopKV2
|
||||
output = _Convert(output, NHWC);
|
||||
auto width = output->getInfo()->dim[2];
|
||||
auto height = output->getInfo()->dim[1];
|
||||
auto channel = output->getInfo()->dim[3];
|
||||
MNN_PRINT("output w = %d, h=%d\n", width, height);
|
||||
|
||||
const int humanIndex = 15;
|
||||
output = _Reshape(output, {-1, channel});
|
||||
auto kv = _TopKV2(output, _Scalar<int>(1));
|
||||
// Use indice in TopKV2's C axis
|
||||
auto index = kv[1];
|
||||
// If is human, set 255, else set 0
|
||||
auto mask = _Select(_Equal(index, _Scalar<int>(humanIndex)), _Scalar<int>(255), _Scalar<int>(0));
|
||||
|
||||
//If need faster, use this code
|
||||
//auto mask = _Equal(index, _Scalar<int>(humanIndex)) * _Scalar<int>(255);
|
||||
|
||||
mask = _Cast<uint8_t>(mask);
|
||||
stbi_write_png(argv[3], width, height, 1, mask->readMap<uint8_t>(), width);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "Utils.hpp"
|
||||
#include <MNN/AutoTime.hpp>
|
||||
#include "core/WrapExecution.hpp"
|
||||
#include "core/OpCommonUtils.hpp"
|
||||
#include "geometry/GeometryComputerUtils.hpp"
|
||||
#include <MNN/expr/ExecutorScope.hpp>
|
||||
#ifdef MNN_EXPR_ENABLE_PROFILER
|
||||
|
@ -127,10 +128,10 @@ Executor::Requirement Executor::getRequirement(Expr* expr) const {
|
|||
return req;
|
||||
}
|
||||
for (int i = 0; i < inputSize; ++i) {
|
||||
req.contentNeedContent[i] = SizeComputer::opNeedContent(op->type(), i);
|
||||
req.contentNeedContent[i] = OpCommonUtils::opNeedContent(op->type(), i);
|
||||
req.shapeNeedContent[i] = false;
|
||||
}
|
||||
auto needIndexId = SizeComputer::needInputContent(op);
|
||||
auto needIndexId = SizeComputer::needInputContent(op, inputSize);
|
||||
for (auto index : needIndexId) {
|
||||
if (index < req.shapeNeedContent.size()) {
|
||||
req.shapeNeedContent[index] = true;
|
||||
|
@ -440,7 +441,7 @@ ErrorCode Executor::ComputeCache::resize() {
|
|||
op = flatbuffers::GetMutableRoot<Op>(cmd.buffer.data());
|
||||
}
|
||||
for (auto v = 0; v<cmd.inputs.size(); ++v) {
|
||||
if (!SizeComputer::opNeedContent(op->type(), v)) {
|
||||
if (!OpCommonUtils::opNeedContent(op->type(), v)) {
|
||||
continue;
|
||||
}
|
||||
auto des = TensorUtils::getDescribe(cmd.inputs[v]);
|
||||
|
@ -495,7 +496,7 @@ ErrorCode Executor::ComputeCache::resize() {
|
|||
auto bn = mExecutions[k]->backend();
|
||||
auto iterType = bn->type();
|
||||
for (int i=0; i<cmd.inputs.size(); ++i) {
|
||||
if (!SizeComputer::opNeedContent(op->type(), i)) {
|
||||
if (!OpCommonUtils::opNeedContent(op->type(), i)) {
|
||||
continue;
|
||||
}
|
||||
auto inpDes = TensorUtils::getDescribe(cmd.inputs[i]);
|
||||
|
@ -550,7 +551,7 @@ ErrorCode Executor::ComputeCache::resize() {
|
|||
return code;
|
||||
}
|
||||
for (auto v = 0; v<cmd.inputs.size(); ++v) {
|
||||
if (!SizeComputer::opNeedContent(op->type(), v)) {
|
||||
if (!OpCommonUtils::opNeedContent(op->type(), v)) {
|
||||
continue;
|
||||
}
|
||||
auto t = cmd.inputs[v];
|
||||
|
|
|
@ -99,8 +99,8 @@ Expr::Expr(int outputSize) {
|
|||
mInside.reset(new Inside(outputSize));
|
||||
mOutputNames.resize(outputSize);
|
||||
}
|
||||
Expr::Expr(Tensor* tensor) {
|
||||
mInside.reset(new Inside(tensor));
|
||||
Expr::Expr(Tensor* tensor, bool own) {
|
||||
mInside.reset(new Inside(tensor, own));
|
||||
mOutputNames.resize(1);
|
||||
}
|
||||
|
||||
|
@ -129,8 +129,8 @@ void Expr::_addLinkForInputs(EXPRP expr) {
|
|||
}
|
||||
}
|
||||
}
|
||||
EXPRP Expr::create(Tensor* tensor) {
|
||||
EXPRP expr(new Expr(tensor));
|
||||
EXPRP Expr::create(Tensor* tensor, bool own) {
|
||||
EXPRP expr(new Expr(tensor, own));
|
||||
expr->mOp = nullptr;
|
||||
expr->mType = VARP::CONSTANT;
|
||||
auto& dstInfo = expr->mInside->mOutputInfos[0];
|
||||
|
@ -566,8 +566,11 @@ void* Variable::readInternal(bool forShape) {
|
|||
auto inside = mFrom->inside();
|
||||
auto originTensor = inside->mOutputTensors[0];
|
||||
if (0 != originTensor->buffer().device) {
|
||||
// For StaticModule will other-device runtime, we may create Variable with other-device's memory
|
||||
// The case won't occured for varibale = INPUT
|
||||
// Need Copy
|
||||
if (nullptr != inside->mHostTensor) {
|
||||
// The Varp will not be created as input, so we just need copy once
|
||||
return inside->mHostTensor->host<void>();
|
||||
}
|
||||
inside->mHostTensor = new Tensor;
|
||||
|
@ -838,7 +841,7 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
|
|||
auto& info = expr->mInside->mOutputInfos[0];
|
||||
const void* ptr = expr->mInside->mOutputTensors[0]->host<void>();
|
||||
VARP temp;
|
||||
if (nullptr == ptr) {
|
||||
if (nullptr == ptr || expr->mInside->mOutputTensors[0]->deviceId() > 0) {
|
||||
temp = Variable::create(expr);
|
||||
ptr = temp->readMap<void>();
|
||||
}
|
||||
|
|
|
@ -392,12 +392,15 @@ output: A variable with the same type as `x`.
|
|||
*/
|
||||
VARP _Reshape(VARP x, VARP shape) {
|
||||
MNN_ASSERT(nullptr != x);
|
||||
MNN_ASSERT(nullptr != x->getInfo());
|
||||
std::unique_ptr<OpT> reshape(new OpT);
|
||||
reshape->type = OpType_Reshape;
|
||||
reshape->main.type = OpParameter_Reshape;
|
||||
reshape->main.value = new ReshapeT;
|
||||
reshape->main.AsReshape()->dimType = (MNN_DATA_FORMAT)Utils::convertFormat(x->getInfo()->order);
|
||||
if (nullptr != x->getInfo()) {
|
||||
reshape->main.AsReshape()->dimType = (MNN_DATA_FORMAT)Utils::convertFormat(x->getInfo()->order);
|
||||
} else {
|
||||
reshape->main.AsReshape()->dimType = MNN_DATA_FORMAT_NHWC;
|
||||
}
|
||||
return (Variable::create(Expr::create(reshape.get(), {x, shape})));
|
||||
}
|
||||
VARP _Scale(VARP x, int channels, std::vector<float>&& scales, std::vector<float>&& bias) {
|
||||
|
@ -425,7 +428,7 @@ VARP _Relu(VARP x, float slope) {
|
|||
relu->main.AsRelu()->slope = slope;
|
||||
return (Variable::create(Expr::create(relu.get(), {x})));
|
||||
}
|
||||
/*Given an input value x, it computes Rectified Linear 6: min(max(x, 0), 6).
|
||||
/*Given an input value x, it computes Rectified Linear 6: min(max(x, 0), 6).
|
||||
Args:
|
||||
x: A variable.
|
||||
Returns:
|
||||
|
@ -1562,6 +1565,36 @@ VARP _CosineSimilarity(VARP input0, VARP input1, VARP inputDim) {
|
|||
return (Variable::create(Expr::create(std::move(cosineSimilarityOp), {input0, input1, inputDim})));
|
||||
}
|
||||
|
||||
VARP _GridSample(VARP input, VARP grid, InterpolationMethod mode, GridSamplePaddingMode paddingMode, bool alignCorners) {
|
||||
std::unique_ptr<OpT> op(new OpT);
|
||||
op->type = OpType_GridSample;
|
||||
op->main.type = OpParameter_GridSample;
|
||||
op->main.value = new GridSampleT;
|
||||
switch (mode) {
|
||||
case NEAREST:
|
||||
op->main.AsGridSample()->mode = SampleMode_NEAREST;
|
||||
break;
|
||||
case BILINEAR:
|
||||
default:
|
||||
op->main.AsGridSample()->mode = SampleMode_BILINEAR;
|
||||
break;
|
||||
}
|
||||
switch (paddingMode) {
|
||||
case GRID_SAMPLE_PADDING_BORDER:
|
||||
op->main.AsGridSample()->paddingMode = BorderMode_CLAMP;
|
||||
break;
|
||||
case GRID_SAMPLE_PADDING_REFLECTION:
|
||||
op->main.AsGridSample()->paddingMode = BorderMode_REFLECTION;
|
||||
break;
|
||||
case GRID_SAMPLE_PADDING_ZEROS:
|
||||
default:
|
||||
op->main.AsGridSample()->paddingMode = BorderMode_ZEROS;
|
||||
break;
|
||||
}
|
||||
op->main.AsGridSample()->alignCorners = alignCorners;
|
||||
return (Variable::create(Expr::create(std::move(op), {input, grid})));
|
||||
}
|
||||
|
||||
VARP _FloatToInt8(VARP x, VARP scale, char minValue/*For future*/, char maxValue/*For future*/) {
|
||||
auto xInfo = x->getInfo();
|
||||
auto scaleInfo = scale->getInfo();
|
||||
|
@ -1574,7 +1607,7 @@ VARP _FloatToInt8(VARP x, VARP scale, char minValue/*For future*/, char maxValue
|
|||
MNN_ERROR("Not Support Input for FloatToInt8 because var not NC4HW4 or not float\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (scaleInfo->size != xInfo->dim[1]) {
|
||||
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
|
||||
MNN_ERROR("Scale's size not match input's channel: %d - %d\n", scaleInfo->size, xInfo->dim[1]);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1599,7 +1632,7 @@ VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t z
|
|||
MNN_ERROR("Not Support Input for FloatToInt8 because var not NC4HW4 or not float\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (scaleInfo->size != xInfo->dim[1]) {
|
||||
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
|
||||
MNN_ERROR("Scale's size not match input's channel: %d - %d\n", scaleInfo->size, xInfo->dim[1]);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1628,7 +1661,7 @@ VARP _Int8ToFloat(VARP x, VARP scale) {
|
|||
MNN_ERROR("Not Support Input for _Int8ToFloat because var not NC4HW4 or not int8\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (scaleInfo->size != xInfo->dim[1]) {
|
||||
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
|
||||
MNN_ERROR("_Int8ToFloat Scale's size not match input's channel\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1653,7 +1686,7 @@ VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint) {
|
|||
MNN_ERROR("Not Support Input for _Int8ToFloat because var not NC4HW4 or not int8\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (scaleInfo->size != xInfo->dim[1]) {
|
||||
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
|
||||
MNN_ERROR("_Int8ToFloat Scale's size not match input's channel\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1673,5 +1706,16 @@ VARP _Select(VARP select, VARP input0, VARP input1) {
|
|||
return (Variable::create(Expr::create(std::move(selectOp), {select, input0, input1})));
|
||||
}
|
||||
|
||||
std::vector<VARP> _TopKV2(VARP input0, VARP input1) {
|
||||
std::unique_ptr<OpT> op(new OpT);
|
||||
op->type = OpType_TopKV2;
|
||||
auto expr = Expr::create(op.get(), {input0, input1}, 2);
|
||||
std::vector<VARP> res(2);
|
||||
res[0] = Variable::create(expr, 0);
|
||||
res[1] = Variable::create(expr, 1);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
||||
|
|
|
@ -25,14 +25,13 @@ Expr::Inside::Inside(int outputSize) {
|
|||
TensorUtils::getDescribe(mOutputTensors[i])->memoryType = Tensor::InsideDescribe::MEMORY_HOST;
|
||||
}
|
||||
}
|
||||
Expr::Inside::Inside(Tensor* tensor) {
|
||||
Expr::Inside::Inside(Tensor* tensor, bool own) {
|
||||
mOutputInfos.resize(1);
|
||||
mOutputTensors.resize(1);
|
||||
mOutputTensors[0] = tensor;
|
||||
Utils::copyTensorToInfo(&mOutputInfos[0], tensor);
|
||||
mOutputInfos[0].syncSize();
|
||||
mOutputInfos[0].tensorArrayAttr = TensorUtils::getDescribe(tensor)->tensorArrayAttr;
|
||||
mOwnTensor = false;
|
||||
mOwnTensor = own;
|
||||
}
|
||||
|
||||
Expr::Inside::~Inside() {
|
||||
|
|
|
@ -29,7 +29,7 @@ struct BufferStorage {
|
|||
};
|
||||
struct Expr::Inside {
|
||||
Inside(int outputSize);
|
||||
Inside(Tensor* tensor);
|
||||
Inside(Tensor* tensor, bool own = false);
|
||||
~ Inside();
|
||||
std::vector<Variable::Info> mOutputInfos;
|
||||
std::vector<Tensor*> mOutputTensors;
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
//
|
||||
// FixModule.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2019/12/16.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include "FixModule.hpp"
|
||||
#include <MNN/expr/ExprCreator.hpp>
|
||||
using namespace MNN::Express;
|
||||
namespace MNN {
|
||||
namespace Express {
|
||||
FixModule::FixModule(std::vector<Express::VARP> output, std::vector<Express::VARP> parameters,
|
||||
std::vector<std::pair<Express::VARP, Express::Dimensionformat>> inputs) {
|
||||
for (auto p : parameters) {
|
||||
addParameter(p);
|
||||
}
|
||||
mInputs = std::move(inputs);
|
||||
mOutput = std::move(output);
|
||||
}
|
||||
void FixModule::onClearCache() {
|
||||
for (auto v : mInputs) {
|
||||
v.first.fix(VARP::INPUT);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Express::VARP> FixModule::onForward(const std::vector<Express::VARP>& inputs) {
|
||||
MNN_ASSERT(inputs.size() == mInputs.size());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto var = inputs[i];
|
||||
var = _Convert(var, mInputs[i].second);
|
||||
Variable::replace(mInputs[i].first, var);
|
||||
}
|
||||
return mOutput;
|
||||
}
|
||||
|
||||
Module* FixModule::clone(CloneContext* ctx) const {
|
||||
FixModule* module(new FixModule);
|
||||
for (auto& it : mInputs) {
|
||||
VARP v = ctx->getOrClone(it.first);
|
||||
module->mInputs.push_back(std::make_pair(v, it.second));
|
||||
}
|
||||
for (auto& it : mOutput) {
|
||||
VARP v = ctx->getOrClone(it);
|
||||
module->mOutput.push_back(v);
|
||||
}
|
||||
return this->cloneBaseTo(ctx, module);
|
||||
}
|
||||
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
|
@ -1,33 +0,0 @@
|
|||
//
|
||||
// FixModule.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2019/12/16.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifndef FixModule_hpp
|
||||
#define FixModule_hpp
|
||||
#include <MNN/expr/Module.hpp>
|
||||
namespace MNN {
|
||||
namespace Express {
|
||||
|
||||
class FixModule : public Module {
|
||||
public:
|
||||
FixModule(std::vector<Express::VARP> output, std::vector<Express::VARP> parameters,
|
||||
std::vector<std::pair<Express::VARP, Express::Dimensionformat>> inputs);
|
||||
virtual ~FixModule() = default;
|
||||
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
|
||||
virtual void onClearCache() override;
|
||||
private:
|
||||
FixModule() = default;
|
||||
|
||||
Module* clone(CloneContext* ctx) const override;
|
||||
|
||||
std::vector<std::pair<Express::VARP, Express::Dimensionformat>> mInputs;
|
||||
std::vector<Express::VARP> mOutput;
|
||||
};
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
||||
|
||||
#endif
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
#include <MNN/expr/Module.hpp>
|
||||
#include <MNN/expr/ExprCreator.hpp>
|
||||
#include "FixModule.hpp"
|
||||
#include "PipelineModule.hpp"
|
||||
#include "core/FileLoader.hpp"
|
||||
|
||||
|
@ -124,15 +123,15 @@ Module* Module::load(const std::vector<std::string>& inputs, const std::vector<s
|
|||
FileLoader loader(fileName);
|
||||
if (!loader.valid()) {
|
||||
MNN_ERROR("Error for open %s\n", fileName);
|
||||
return {};
|
||||
return nullptr;
|
||||
}
|
||||
loader.read();
|
||||
if (!loader.valid()) {
|
||||
return {};
|
||||
return nullptr;
|
||||
}
|
||||
loader.merge(buffer);
|
||||
if (buffer.get() == nullptr) {
|
||||
return {};
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return load(inputs, outputs, buffer.get(), buffer.size(), config);
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
#include <MNN/expr/NN.hpp>
|
||||
#include "Distributions.hpp"
|
||||
#include "FixModule.hpp"
|
||||
#include "PipelineModule.hpp"
|
||||
#include "WhileModule.hpp"
|
||||
#include "IfModule.hpp"
|
||||
#include "Initializer.hpp"
|
||||
|
@ -364,11 +364,11 @@ Module* NN::ConvTranspose(const ConvOption& option, bool hasBias,
|
|||
if (nullptr != bias) {
|
||||
auto tempOutput = _Deconv(weight, bias, input, option.padMode, option.stride, option.dilate, group);
|
||||
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
|
||||
return new FixModule({tempOutput}, {weight, bias}, {{input, NC4HW4}});
|
||||
return PipelineModule::extract({input}, {tempOutput}, true);
|
||||
}
|
||||
auto tempOutput = _Deconv(weight, nullptr, input, option.padMode, option.stride, option.dilate, group);
|
||||
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
|
||||
return new FixModule({tempOutput}, {weight}, {{input, NC4HW4}});
|
||||
return PipelineModule::extract({input}, {tempOutput}, true);
|
||||
}
|
||||
Module* NN::Conv(const ConvOption& option, bool hasBias, std::shared_ptr<Initializer> weightInit,
|
||||
std::shared_ptr<Initializer> biasInit) {
|
||||
|
@ -397,12 +397,12 @@ Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr<Initializer> weig
|
|||
auto input = _Input({l}, NCHW);
|
||||
auto output = _MatMul(input, weight, false, true);
|
||||
if (!hasBias) {
|
||||
return new FixModule({output}, {weight}, {{input, NCHW}});
|
||||
return PipelineModule::extract({input}, {output}, true);
|
||||
}
|
||||
auto bias = biasInit->createConstVar({1, t}, NCHW);
|
||||
bias.fix(VARP::TRAINABLE);
|
||||
output = _Add(output, bias);
|
||||
auto module = new FixModule({output}, {weight, bias}, {{input, NCHW}});
|
||||
auto module = PipelineModule::extract({input}, {output}, true);
|
||||
module->setType("Linear");
|
||||
return module;
|
||||
}
|
||||
|
@ -508,133 +508,10 @@ NN::ConvParameters NN::Utils::ExtractConvolution(EXPRP source) {
|
|||
return _default;
|
||||
}
|
||||
|
||||
static int _clamp(int c, int maxValue, int minValue) {
|
||||
if (c > maxValue) {
|
||||
return maxValue;
|
||||
}
|
||||
if (c < minValue) {
|
||||
return minValue;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
class ConvOctaveModule : public Module {
|
||||
public:
|
||||
ConvOctaveModule(const NN::ConvOption& option, VARP weight, VARP bias, int group, float inFactor, float outFactor)
|
||||
: mOption(option) {
|
||||
auto inputCountC4 = UP_DIV(option.channel[0], 4);
|
||||
auto outputCountC4 = UP_DIV(option.channel[1], 4);
|
||||
MNN_ASSERT(inputCountC4 > 1 && outputCountC4 > 1);
|
||||
MNN_ASSERT(nullptr != bias);
|
||||
auto iC0 = (int)((float)inputCountC4 * inFactor);
|
||||
iC0 = _clamp(iC0, inputCountC4 - 1, 1);
|
||||
|
||||
auto oC0 = (int)((float)outputCountC4 * outFactor);
|
||||
oC0 = _clamp(oC0, outputCountC4 - 1, 1);
|
||||
|
||||
iC0 = iC0 * 4;
|
||||
auto iC1 = option.channel[0] - iC0;
|
||||
oC0 = oC0 * 4;
|
||||
auto oC1 = option.channel[1] - oC0;
|
||||
mSplitInput = {iC0, iC1};
|
||||
|
||||
MNN_PRINT("Octave: %d, %d -> %d - %d, %d-%d\n", option.channel[0], option.channel[1], iC0, iC1, oC0, oC1);
|
||||
auto splitBias = _Split(bias * _Scalar<float>(0.5f), {oC0, oC1}, 0);
|
||||
mLBias = splitBias[0];
|
||||
mHBias = splitBias[1];
|
||||
mLBias.fix(VARP::TRAINABLE);
|
||||
mHBias.fix(VARP::TRAINABLE);
|
||||
|
||||
auto splitWeight = _Split(weight, {oC0, oC1}, 0);
|
||||
auto lw = _Split(splitWeight[0], {iC0, iC1}, 1);
|
||||
auto hw = _Split(splitWeight[1], {iC0, iC1}, 1);
|
||||
mLLW = lw[0];
|
||||
mLHW = lw[1];
|
||||
mHLW = hw[0];
|
||||
mHHW = hw[1];
|
||||
|
||||
mLLW.fix(VARP::TRAINABLE);
|
||||
mLHW.fix(VARP::TRAINABLE);
|
||||
mHLW.fix(VARP::TRAINABLE);
|
||||
mHHW.fix(VARP::TRAINABLE);
|
||||
mGroup = group;
|
||||
addParameter(mLBias);
|
||||
addParameter(mHBias);
|
||||
addParameter(mLLW);
|
||||
addParameter(mLHW);
|
||||
addParameter(mHHW);
|
||||
addParameter(mHLW);
|
||||
setType("ConvOctave");
|
||||
}
|
||||
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
|
||||
auto input = _Convert(inputs[0], NC4HW4);
|
||||
auto inputSplit = _Split(input, mSplitInput, 1);
|
||||
auto XL = inputSplit[0];
|
||||
auto XH = inputSplit[1];
|
||||
if (input->getInfo()->dim[3] < 2) {
|
||||
auto L2L = _Conv(mLLW, mLBias, XL, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto L2H = _Conv(mHLW, mHBias, XL, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto H2L = _Conv(mLHW, mLBias, XH, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto H2H = _Conv(mHHW, mHBias, XH, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto L = L2L + H2L;
|
||||
auto H = H2H + L2H;
|
||||
return {_Concat({L, H}, 1)};
|
||||
}
|
||||
XL = _AvePool(XL, {2, 2}, {2, 2});
|
||||
auto info = XL->getInfo();
|
||||
auto L2L = _Conv(mLLW, mLBias, XL, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto L2H = _Conv(mHLW, mHBias, XL, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto H2L =
|
||||
_Conv(mLHW, mLBias, _AvePool(XH, {2, 2}, {2, 2}), mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto H2H = _Conv(mHHW, mHBias, XH, mOption.padMode, mOption.stride, mOption.dilate, mGroup);
|
||||
auto L = L2L + H2L;
|
||||
auto H = H2H;
|
||||
auto dstShape = H->getInfo()->dim; // NCHW
|
||||
{ H = H2H + _Interp({L2H}, 0.0f, 0.0f, dstShape[3], dstShape[2], 1, true); }
|
||||
auto res = _Concat({_Interp({L}, 0.0f, 0.0f, dstShape[3], dstShape[2], 1, true), H}, 1);
|
||||
info = res->getInfo();
|
||||
MNN_ASSERT(nullptr != info);
|
||||
return {_activate(res, mOption.fusedActivationFunction)};
|
||||
}
|
||||
|
||||
private:
|
||||
ConvOctaveModule() = default;
|
||||
|
||||
Module* clone(CloneContext* ctx) const override {
|
||||
ConvOctaveModule* module(new ConvOctaveModule);
|
||||
module->mOption = mOption;
|
||||
module->mLLW = ctx->getOrClone(mLLW);
|
||||
module->mLHW = ctx->getOrClone(mLHW);
|
||||
module->mHLW = ctx->getOrClone(mHLW);
|
||||
module->mHHW = ctx->getOrClone(mHHW);
|
||||
module->mLBias = ctx->getOrClone(mLBias);
|
||||
module->mHBias = ctx->getOrClone(mHBias);
|
||||
module->mSplitInput = mSplitInput;
|
||||
module->mGroup = mGroup;
|
||||
return this->cloneBaseTo(ctx, module);
|
||||
}
|
||||
|
||||
NN::ConvOption mOption;
|
||||
VARP mLLW;
|
||||
VARP mLHW;
|
||||
VARP mHLW;
|
||||
VARP mHHW;
|
||||
VARP mLBias;
|
||||
VARP mHBias;
|
||||
|
||||
std::vector<int> mSplitInput;
|
||||
int mGroup;
|
||||
};
|
||||
|
||||
Module* NN::Conv(const ConvParameters& parameter) {
|
||||
return new ConvModule(parameter);
|
||||
}
|
||||
|
||||
Module* NN::ConvOctave(const ConvParameters& parameters,
|
||||
float inFactor, float outFactor) {
|
||||
auto module = new ConvOctaveModule(parameters.option, parameters.weight, parameters.bias, parameters.group, inFactor, outFactor);
|
||||
module->setName(parameters.name);
|
||||
return module;
|
||||
}
|
||||
Module* NN::Utils::ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs) {
|
||||
if (nullptr == expr->get()) {
|
||||
return nullptr;
|
||||
|
@ -701,46 +578,90 @@ public:
|
|||
mActivation = mOption.fusedActivationFunction;
|
||||
}
|
||||
|
||||
mFeatureScaleStatMethod = featureScaleStatMethod;
|
||||
if (featureScaleStatMethod == NN::PerChannel) {
|
||||
MNN_PRINT("PerChannel quantization for feature is deprecated, use PerTensor method instead.\n");
|
||||
return;
|
||||
}
|
||||
|
||||
mFeatureScaleStatMethod = NN::PerTensor;
|
||||
mScaleUpdateMethod = scaleUpdateMethod;
|
||||
|
||||
mBits = bits;
|
||||
auto limit = (float)(1 << (bits - 1)) - 1.0f;
|
||||
mLimitScale = _Scalar<float>(1.0f / limit);
|
||||
mClampValue = _Scalar<float>(limit);
|
||||
mLimit = (float)(1 << (bits - 1)) - 1.0f;
|
||||
mLimitScale = _Scalar<float>(1.0f / mLimit);
|
||||
mWeightClampValue = _Scalar<float>(mLimit);
|
||||
mInputClampValue = _Scalar<float>(mLimit);
|
||||
mOutputClampValue = _Scalar<float>(mLimit);
|
||||
|
||||
mInputScalePos = addParameter(mInputScale);
|
||||
mOutputScalePos = addParameter(mOutputScale);
|
||||
mInputMinPos = addParameter(mInputMin);
|
||||
mInputMaxPos = addParameter(mInputMax);
|
||||
mOutputMinPos = addParameter(mOutputMin);
|
||||
mOutputMaxPos = addParameter(mOutputMax);
|
||||
|
||||
setType("ConvBNReluFused");
|
||||
}
|
||||
|
||||
std::pair<VARP, VARP> fakeQuantFeature(VARP x, VARP useScale = nullptr) {
|
||||
std::pair<VARP, VARP> computeScaleAndZeroPoint(VARP min, VARP max, VARP clampVar) {
|
||||
MNN_ASSERT((!(min == nullptr)));
|
||||
MNN_ASSERT((!(max == nullptr)));
|
||||
|
||||
min = _Minimum(_Scalar<float>(0.0f), min);
|
||||
max = _Maximum(_Scalar<float>(0.0f), max);
|
||||
|
||||
auto scale = (max - min) / (_Scalar(2.0f) * clampVar);
|
||||
auto zeroPoint = _Round((_Scalar(0.0f) - min) / scale - clampVar);
|
||||
|
||||
return std::make_pair(scale, zeroPoint);
|
||||
}
|
||||
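For reference, a minimal standalone sketch of the scale and zero-point arithmetic implemented just above (plain C++, not the MNN expression API; the sample min/max values are made up, and clampValue = 127 corresponds to bits = 8):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float clampValue = 127.0f;             // mirrors mInputClampValue / mOutputClampValue
        float minV = std::min(0.0f, -0.3f);          // observed minimum (illustrative value)
        float maxV = std::max(0.0f, 2.5f);           // observed maximum (illustrative value)
        float scale     = (maxV - minV) / (2.0f * clampValue);
        float zeroPoint = std::round((0.0f - minV) / scale - clampValue);
        // quantize / dequantize one value with the derived parameters
        float x  = 1.0f;
        float q  = std::max(-clampValue, std::min(clampValue, std::round(x / scale + zeroPoint)));
        float dq = scale * (q - zeroPoint);
        std::printf("scale=%f zeroPoint=%f q=%f dq=%f\n", scale, zeroPoint, q, dq);
        return 0;
    }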
|
||||
std::vector<VARP> fakeQuantFeatureWithMinMax(VARP x, VARP useMin, VARP useMax, VARP clampVar) {
|
||||
auto originFormat = x->getInfo()->order;
|
||||
auto tempX = x;
|
||||
if (originFormat == NC4HW4) {
|
||||
tempX = _Convert(tempX, NCHW);
|
||||
}
|
||||
auto originX = tempX;
|
||||
VARP scale = _Maximum(_ReduceMax(_Abs(tempX)), _Scalar<float>(0.0001f)) * mLimitScale;
|
||||
if (useScale == nullptr) {
|
||||
tempX = _Round(tempX * _Reciprocal(scale)) * scale;
|
||||
VARP min, max;
|
||||
// always PerTensor
|
||||
min = _ReduceMin(tempX);
|
||||
max = _ReduceMax(tempX);
|
||||
|
||||
VARP scale, zeroPoint;
|
||||
VARP nudgeMin, nudgeMax;
|
||||
|
||||
if (!(useMin == nullptr)) {
|
||||
MNN_ASSERT(!(useMax == nullptr));
|
||||
auto scaleAndZeroPoint = computeScaleAndZeroPoint(useMin, useMax, clampVar);
|
||||
scale = scaleAndZeroPoint.first;
|
||||
zeroPoint = scaleAndZeroPoint.second;
|
||||
} else {
|
||||
tempX = _Round(tempX * _Reciprocal(useScale)) * useScale;
|
||||
auto scaleAndZeroPoint = computeScaleAndZeroPoint(min, max, clampVar);
|
||||
scale = scaleAndZeroPoint.first;
|
||||
zeroPoint = scaleAndZeroPoint.second;
|
||||
}
|
||||
|
||||
float limit = clampVar->readMap<float>()[0];
|
||||
nudgeMin = (_Scalar<float>(-limit) - zeroPoint) * scale;
|
||||
nudgeMax = (_Scalar<float>(limit) - zeroPoint) * scale;
|
||||
|
||||
nudgeMin = _Minimum(_Scalar<float>(0.0f), nudgeMin);
|
||||
nudgeMax = _Maximum(_Scalar<float>(0.0f), nudgeMax);
|
||||
|
||||
auto quantX = clamp(_Round(tempX / scale + zeroPoint), clampVar);
|
||||
tempX = scale * (quantX - zeroPoint);
|
||||
// Break the gradient flow by using a cast
|
||||
tempX = _Cast<float>(tempX);
|
||||
|
||||
// Move grad from tempX to originX
|
||||
tempX = _Convert(tempX + _ZeroGrad(originX), originFormat);
|
||||
return std::make_pair(tempX, scale);
|
||||
|
||||
return {tempX, nudgeMin, nudgeMax};
|
||||
}
|
||||
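A minimal sketch of the forward effect of the fake quantization above on a plain float buffer (standalone C++, not the MNN expression API; the straight-through gradient trick done with _Cast and _ZeroGrad has no counterpart here because there is no autograd graph):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Quantize-then-dequantize a float buffer with a given scale / zero point,
    // using the same round -> clamp -> rescale sequence as above.
    std::vector<float> fakeQuant(const std::vector<float>& x, float scale, float zeroPoint, float limit) {
        std::vector<float> out(x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            float q = std::round(x[i] / scale + zeroPoint);   // quantize
            q = std::max(-limit, std::min(limit, q));         // clamp to [-limit, limit]
            out[i] = scale * (q - zeroPoint);                 // dequantize back to float
        }
        return out;
    }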
|
||||
VARP clamp(VARP x) {
|
||||
return _Maximum(_Minimum(x, mClampValue), _Negative(mClampValue));
|
||||
VARP clamp(VARP x, VARP clampVar) {
|
||||
return _Maximum(_Minimum(x, clampVar), _Negative(clampVar));
|
||||
}
|
||||
|
||||
VARP updateScale(VARP originValue, VARP newValue) const {
|
||||
VARP updateParameter(VARP originValue, VARP newValue) const {
|
||||
if (nullptr == originValue) {
|
||||
return newValue;
|
||||
}
|
||||
|
@ -761,20 +682,21 @@ public:
|
|||
if (getIsTraining()) {
|
||||
auto x = _Convert(inputs[0], NCHW);
|
||||
// simulate weight quant
|
||||
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
|
||||
auto weightTemp = _Round(mWeight * _Reciprocal(weightScale)) * weightScale;
|
||||
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
|
||||
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
|
||||
weightTemp = weightTemp + _ZeroGrad(mWeight);
|
||||
|
||||
// simulate input quant to get original input scale
|
||||
auto inputPair = fakeQuantFeature(x);
|
||||
mInputScale = updateScale(mInputScale, inputPair.second);
|
||||
setParameter(mInputScale, mInputScalePos);
|
||||
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
|
||||
mInputMin = updateParameter(mInputMin, inputPair[1]);
|
||||
mInputMax = updateParameter(mInputMax, inputPair[2]);
|
||||
setParameter(mInputMin, mInputMinPos);
|
||||
setParameter(mInputMax, mInputMaxPos);
|
||||
|
||||
// simulate output quant to get original output scale
|
||||
res = _Conv(weightTemp, mBias, _Convert(inputPair.first, NC4HW4), mOption.padMode, mOption.stride,
|
||||
res = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
|
||||
mOption.dilate, mGroup, mOption.pads);
|
||||
res->setName(name());
|
||||
auto conv = res;
|
||||
|
||||
if (mBatchNorm) {
|
||||
res = mBatchNorm->forward(res);
|
||||
|
@ -782,25 +704,29 @@ public:
|
|||
|
||||
res = _activate(res, mActivation);
|
||||
|
||||
auto outputPair = fakeQuantFeature(res);
|
||||
mOutputScale = updateScale(mOutputScale, outputPair.second);
|
||||
setParameter(mOutputScale, mOutputScalePos);
|
||||
res = outputPair.first;
|
||||
auto outputPair = fakeQuantFeatureWithMinMax(res, nullptr, nullptr, mOutputClampValue);
|
||||
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
|
||||
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
|
||||
setParameter(mOutputMin, mOutputMinPos);
|
||||
setParameter(mOutputMax, mOutputMaxPos);
|
||||
|
||||
res = outputPair[0];
|
||||
} else {
|
||||
if (nullptr == mInputScale) {
|
||||
if (nullptr == mInputMin) {
|
||||
// Initialize scales on the first inference (test) run
|
||||
// simulate weight quant
|
||||
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
|
||||
weightScale.fix(VARP::CONSTANT);
|
||||
auto weightTemp = _Round(mWeight * _Reciprocal(weightScale)) * weightScale;
|
||||
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
|
||||
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
|
||||
|
||||
auto x = _Convert(inputs[0], NCHW);
|
||||
auto inputPair = fakeQuantFeature(x);
|
||||
mInputScale = inputPair.second;
|
||||
setParameter(mInputScale, mInputScalePos);
|
||||
inputPair.first.fix(VARP::CONSTANT);
|
||||
|
||||
auto simuRes = _Conv(weightTemp, mBias, _Convert(inputPair.first, NC4HW4), mOption.padMode, mOption.stride,
|
||||
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
|
||||
mInputMin = updateParameter(mInputMin, inputPair[1]);
|
||||
mInputMax = updateParameter(mInputMax, inputPair[2]);
|
||||
setParameter(mInputMin, mInputMinPos);
|
||||
setParameter(mInputMax, mInputMaxPos);
|
||||
|
||||
auto simuRes = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
|
||||
mOption.dilate, mGroup, mOption.pads);
|
||||
if (mBatchNorm) {
|
||||
simuRes = mBatchNorm->forward(simuRes);
|
||||
|
@ -808,10 +734,12 @@ public:
|
|||
simuRes = _activate(simuRes, mActivation);
|
||||
|
||||
Variable::prepareCompute({simuRes});
|
||||
auto outputPair = fakeQuantFeature(simuRes);
|
||||
mOutputScale = outputPair.second;
|
||||
setParameter(mOutputScale, mOutputScalePos);
|
||||
outputPair.first.fix(VARP::CONSTANT);
|
||||
|
||||
auto outputPair = fakeQuantFeatureWithMinMax(simuRes, nullptr, nullptr, mOutputClampValue);
|
||||
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
|
||||
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
|
||||
setParameter(mOutputMin, mOutputMinPos);
|
||||
setParameter(mOutputMax, mOutputMaxPos);
|
||||
}
|
||||
|
||||
// fold BN into conv weights and bias
|
||||
|
@ -833,21 +761,39 @@ public:
|
|||
|
||||
alpha = _Reshape(alpha, {alpha->getInfo()->size, 1, 1, 1});
|
||||
beta = _Reshape(beta, {beta->getInfo()->size, 1, 1, 1});
|
||||
alpha.fix(VARP::CONSTANT);
|
||||
beta.fix(VARP::CONSTANT);
|
||||
|
||||
fusedWeights = alpha * fusedWeights;
|
||||
fusedBias = alpha * fusedBias + beta;
|
||||
fusedWeights.fix(VARP::CONSTANT);
|
||||
fusedBias.fix(VARP::CONSTANT);
|
||||
}
|
||||
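A hedged sketch of the per-channel BatchNorm folding that the block above performs, assuming the usual definitions alpha = gamma / sqrt(var + eps) and beta = bnBias - mean * alpha (how alpha and beta are derived is not shown in this hunk):

    #include <cmath>
    #include <vector>

    // Fold BatchNorm into the preceding convolution, one output channel at a time.
    void foldBatchNorm(std::vector<float>& weights, std::vector<float>& bias,     // conv parameters
                       const std::vector<float>& gamma, const std::vector<float>& bnBias,
                       const std::vector<float>& mean, const std::vector<float>& var,
                       float eps, int weightsPerChannel) {
        for (size_t oc = 0; oc < bias.size(); ++oc) {
            float alpha = gamma[oc] / std::sqrt(var[oc] + eps);
            float beta  = bnBias[oc] - mean[oc] * alpha;
            for (int k = 0; k < weightsPerChannel; ++k) {
                weights[oc * weightsPerChannel + k] *= alpha;   // w' = alpha * w
            }
            bias[oc] = alpha * bias[oc] + beta;                 // b' = alpha * b + beta
        }
    }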
|
||||
auto x = _Convert(inputs[0], NC4HW4);
|
||||
|
||||
int8_t inputZeroPoint, outputZeroPoint;
|
||||
{
|
||||
std::vector<int> dims = {x->getInfo()->dim[1]};
|
||||
auto dimVar = _Const(dims.data(), {1}, NCHW, halide_type_of<int32_t>());
|
||||
VARP channelScale = _Reciprocal(_Fill(dimVar, mInputScale));
|
||||
x = _FloatToInt8(x, channelScale, -127, 127);// TODO add clamp
|
||||
VARP channelScale, zeroPoint;
|
||||
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mInputMin, mInputMax, mInputClampValue);
|
||||
mInputScale = scaleAndZeroPoint.first;
|
||||
mInputZeroPoint = scaleAndZeroPoint.second;
|
||||
|
||||
// always PerTensor
|
||||
channelScale = _Reciprocal(mInputScale);
|
||||
zeroPoint = _Cast<int8_t>(mInputZeroPoint);
|
||||
|
||||
inputZeroPoint = zeroPoint->readMap<int8_t>()[0];
|
||||
|
||||
x = _FloatToInt8(x, channelScale, -int8_t(mInputClampValue->readMap<float>()[0]), int8_t(mInputClampValue->readMap<float>()[0]), inputZeroPoint);
|
||||
}
|
||||
{
|
||||
VARP channelScale, zeroPoint;
|
||||
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mOutputMin, mOutputMax, mOutputClampValue);
|
||||
mOutputScale = scaleAndZeroPoint.first;
|
||||
mOutputZeroPoint = scaleAndZeroPoint.second;
|
||||
|
||||
// always PerTensor
|
||||
channelScale = mOutputScale;
|
||||
zeroPoint = _Cast<int8_t>(mOutputZeroPoint);
|
||||
|
||||
outputZeroPoint = zeroPoint->readMap<int8_t>()[0];
|
||||
}
|
||||
|
||||
std::vector<int8_t> weight;
|
||||
|
@ -855,19 +801,18 @@ public:
|
|||
std::vector<float> scale;
|
||||
{
|
||||
VARP weightScale, quanWeight, convScale;
|
||||
if (mOption.depthwise) {
|
||||
auto newWeight = fusedWeights * _Reshape(mInputScale, {-1, 1, 1, 1});
|
||||
weightScale = _Maximum(_ReduceMax(_Abs(newWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
|
||||
quanWeight = _Cast<int8_t>(_Round(newWeight * _Reciprocal(weightScale)));
|
||||
convScale = _Reshape(_Reciprocal(mOutputScale), {-1, 1, 1, 1}) * weightScale;
|
||||
} else {
|
||||
auto newWeight = fusedWeights * mInputScale;
|
||||
weightScale = _Maximum(_ReduceMax(_Abs(newWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
|
||||
quanWeight = _Cast<int8_t>(_Round(newWeight * _Reciprocal(weightScale)));
|
||||
convScale = _Reshape(_Reciprocal(mOutputScale), {-1, 1, 1, 1}) * weightScale;
|
||||
}
|
||||
auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale));
|
||||
Variable::prepareCompute({quanBias, quanWeight, convScale});
|
||||
auto newWeight = fusedWeights * mInputScale;
|
||||
weightScale = _Maximum(_ReduceMax(_Abs(newWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
|
||||
quanWeight = _Cast<int8_t>(_Round(newWeight * _Reciprocal(weightScale)));
|
||||
convScale = _Reciprocal(mOutputScale) * weightScale;
|
||||
Variable::prepareCompute({quanWeight, convScale});
|
||||
|
||||
auto remains = _ReduceSum(_Cast<int32_t>(mInputZeroPoint) * _Cast<int32_t>(quanWeight), {1, 2, 3}, true);
|
||||
MNN_ASSERT((mOutputZeroPoint->getInfo()->dim.size() == 0) && (mOutputZeroPoint->getInfo()->size == 1)); // only per-tensor is supported; per-channel has been removed.
|
||||
auto outputZeroPointFused = _Cast<int32_t>(_Cast<float>(mOutputZeroPoint) * _Reciprocal(convScale));
|
||||
auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale)) - remains + outputZeroPointFused;
|
||||
Variable::prepareCompute({quanBias});
|
||||
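A standalone sketch of the requantization arithmetic the block above prepares, namely q_y = convScale * (sum(q_w * q_x) + quanBias), where quanBias folds in the float bias, the input zero point and the output zero point (plain C++, single output channel, illustrative values only):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int8_t> qw = {12, -7, 33};        // quantized weights for one output channel
        std::vector<int8_t> qx = {5, 100, -20};       // quantized input values
        float weightScale = 0.02f, outputScale = 0.05f;
        int32_t inputZeroPoint = 3, outputZeroPoint = -1;
        float fusedBias = 0.4f;

        float convScale = weightScale / outputScale;  // matches _Reciprocal(mOutputScale) * weightScale
        int32_t sumW = 0, acc = 0;
        for (size_t i = 0; i < qw.size(); ++i) {
            sumW += qw[i];
            acc  += int32_t(qw[i]) * int32_t(qx[i]);
        }
        int32_t quanBias = int32_t(std::lround(fusedBias / weightScale))
                         - inputZeroPoint * sumW                              // the "remains" term
                         + int32_t(std::lround(outputZeroPoint / convScale)); // fused output zero point
        float qy = convScale * float(acc + quanBias); // would then be rounded and clamped to int8
        std::printf("quantized output (before clamp) = %f\n", qy);
        return 0;
    }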
|
||||
{
|
||||
auto info = quanWeight->getInfo();
|
||||
weight.resize(info->size);
|
||||
|
@ -888,14 +833,13 @@ public:
|
|||
}
|
||||
bool relu = mActivation == NN::None ? false : true;
|
||||
res = _Conv(std::move(weight), std::move(bias), std::move(scale), _Convert(x, NC4HW4), mOption.channel,
|
||||
mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu, 0, 0, -int8_t(mClampValue->readMap<float>()[0]), int8_t(mClampValue->readMap<float>()[0]), false);
|
||||
mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu,
|
||||
inputZeroPoint, outputZeroPoint,
|
||||
-int8_t(mOutputClampValue->readMap<float>()[0]), int8_t(mOutputClampValue->readMap<float>()[0]), mAccumulateToInt16);
|
||||
res->setName(name());
|
||||
{
|
||||
std::vector<int> dims = {res->getInfo()->dim[1]};
|
||||
auto dimVar = _Const(dims.data(), {1}, NCHW, halide_type_of<int32_t>());
|
||||
VARP channelScale = _Fill(dimVar, mOutputScale);
|
||||
res = _Int8ToFloat(res, channelScale);
|
||||
}
|
||||
|
||||
// always PerTensor
|
||||
res = _Int8ToFloat(res, mOutputScale, outputZeroPoint);
|
||||
}
|
||||
|
||||
return {res};
|
||||
|
@ -915,12 +859,23 @@ private:
|
|||
module->mBias = ctx->getOrClone(mBias);
|
||||
module->mActivation = mActivation;
|
||||
module->mBits = mBits;
|
||||
module->mLimit = mLimit;
|
||||
module->mLimitScale = ctx->getOrClone(mLimitScale);
|
||||
module->mInputScalePos = mInputScalePos;
|
||||
module->mOutputScalePos = mOutputScalePos;
|
||||
module->mWeightClampValue = ctx->getOrClone(mWeightClampValue);
|
||||
module->mInputScale = ctx->getOrClone(mInputScale);
|
||||
module->mOutputScale = ctx->getOrClone(mOutputScale);
|
||||
module->mClampValue = ctx->getOrClone(mClampValue);
|
||||
module->mInputMin = ctx->getOrClone(mInputMin);
|
||||
module->mInputMax = ctx->getOrClone(mInputMax);
|
||||
module->mOutputMin = ctx->getOrClone(mOutputMin);
|
||||
module->mOutputMax = ctx->getOrClone(mOutputMax);
|
||||
module->mInputZeroPoint = ctx->getOrClone(mInputZeroPoint);
|
||||
module->mOutputZeroPoint = ctx->getOrClone(mOutputZeroPoint);
|
||||
module->mInputMinPos = mInputMinPos;
|
||||
module->mInputMaxPos = mInputMaxPos;
|
||||
module->mOutputMinPos = mOutputMinPos;
|
||||
module->mOutputMaxPos = mOutputMaxPos;
|
||||
module->mInputClampValue = ctx->getOrClone(mInputClampValue);
|
||||
module->mOutputClampValue = ctx->getOrClone(mOutputClampValue);
|
||||
module->mMomentum = mMomentum;
|
||||
module->mFeatureScaleStatMethod = mFeatureScaleStatMethod;
|
||||
module->mScaleUpdateMethod = mScaleUpdateMethod;
|
||||
|
@ -939,15 +894,27 @@ private:
|
|||
NN::ActivationFunctionType mActivation = NN::ActivationFunctionType::None;
|
||||
std::shared_ptr<Module> mBatchNorm = nullptr;
|
||||
int mBits;
|
||||
float mLimit;
|
||||
VARP mLimitScale;
|
||||
int mInputScalePos = -1;
|
||||
int mOutputScalePos = -1;
|
||||
Express::VARP mWeightClampValue;
|
||||
VARP mInputScale = nullptr;
|
||||
VARP mOutputScale = nullptr;
|
||||
VARP mClampValue;
|
||||
VARP mInputMin = nullptr;
|
||||
VARP mInputMax = nullptr;
|
||||
VARP mOutputMin = nullptr;
|
||||
VARP mOutputMax = nullptr;
|
||||
VARP mInputZeroPoint = nullptr;
|
||||
VARP mOutputZeroPoint = nullptr;
|
||||
int mInputMinPos = -1;
|
||||
int mInputMaxPos = -1;
|
||||
int mOutputMinPos = -1;
|
||||
int mOutputMaxPos = -1;
|
||||
VARP mInputClampValue;
|
||||
VARP mOutputClampValue;
|
||||
float mMomentum = 0.99f;
|
||||
NN::FeatureScaleStatMethod mFeatureScaleStatMethod;
|
||||
NN::ScaleUpdateMethod mScaleUpdateMethod;
|
||||
bool mAccumulateToInt16 = false;
|
||||
};
|
||||
|
||||
Module* NN::ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules,
|
||||
|
@ -967,4 +934,4 @@ Module* NN::ConvInt8(const ConvParameters& para, int bits, NN::FeatureScaleStatM
|
|||
}
|
||||
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
||||
} // namespace MNN
|
||||
|
|
|
@ -425,6 +425,7 @@ void PipelineModule::_createSubGraph(const MNN::Net* net, const Module::Config*
|
|||
std::unique_ptr<NetT> _tempNet(new NetT);
|
||||
_tempNet->oplists = std::move(_tempInfo->nodes);
|
||||
_tempNet->tensorName = std::move(_tempInfo->tensors);
|
||||
_tempNet->extraTensorDescribe = std::move(_tempInfo->extraTensorDescribe);
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = Net::Pack(builder, _tempNet.get());
|
||||
builder.Finish(offset);
|
||||
|
@ -598,6 +599,13 @@ static Module* _createSubModule(const MNN::Net* net, const SubModuleInfo& info,
|
|||
for (int i=0; i<net->tensorName()->size(); ++i) {
|
||||
_tempNet->tensorName[i] = net->tensorName()->GetAsString(i)->str();
|
||||
}
|
||||
// Copy Tensor Describe for quant model
|
||||
if (net->extraTensorDescribe()) {
|
||||
_tempNet->extraTensorDescribe.resize(net->extraTensorDescribe()->size());
|
||||
for (int i=0; i<net->extraTensorDescribe()->size(); ++i) {
|
||||
_tempNet->extraTensorDescribe[i].reset(net->extraTensorDescribe()->Get(i)->UnPack());
|
||||
}
|
||||
}
|
||||
// Create Input node
|
||||
std::vector<std::string> inputNames;
|
||||
for (auto index : info.inputs) {
|
||||
|
@ -727,6 +735,12 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
// Build the stack map: key = original tensor index, value = new stack index
|
||||
std::map<int, int> stackMap;
|
||||
int stackIndex = 0;
|
||||
for (auto index : inputIndexesVec) {
|
||||
if (stackMap.find(index) == stackMap.end()) {
|
||||
stackMap.insert(std::make_pair(index, stackIndex));
|
||||
stackIndex++;
|
||||
}
|
||||
}
|
||||
for (auto& m : subModulesInfo) {
|
||||
for (auto index : m.inputs) {
|
||||
if (stackMap.find(index) == stackMap.end()) {
|
||||
|
@ -742,6 +756,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
}
|
||||
}
|
||||
result->mStackSize = stackMap.size();
|
||||
MNN_ASSERT(result->mStackSize > 0);
|
||||
for (int i=0; i<subModulesInfo.size(); ++i) {
|
||||
auto& info = subModulesInfo[i];
|
||||
// Reindex stack index
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#include <MNN/expr/ExprCreator.hpp>
|
||||
#include "Utils.hpp"
|
||||
#include "core/MNNMemoryUtils.h"
|
||||
#include "core/Schedule.hpp"
|
||||
#include "core/Session.hpp"
|
||||
#include "core/TensorUtils.hpp"
|
||||
|
||||
|
@ -24,15 +23,60 @@ static std::shared_ptr<BufferStorage> preRearrangeWeights( // NOLINT
|
|||
const MNN::Net* net, std::map<const Op*, std::shared_ptr<Execution>>& cache, Backend* backend) {
|
||||
std::unique_ptr<MNN::NetT> net_table(net->UnPack());
|
||||
std::map<int, std::shared_ptr<Execution>> exeCache;
|
||||
bool isQuantModel = !net_table->extraTensorDescribe.empty();
|
||||
std::vector<TensorQuantInfoT*> quantInfos;
|
||||
std::vector<std::unique_ptr<Tensor>> inputTensors;
|
||||
if (isQuantModel) {
|
||||
quantInfos.resize(net_table->tensorName.size(), nullptr);
|
||||
for (auto& tensorDes : net_table->extraTensorDescribe) {
|
||||
quantInfos[tensorDes->index] = tensorDes->quantInfo.get();
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < net->oplists()->size(); ++i) {
|
||||
auto op = net->oplists()->Get(i);
|
||||
auto op_table = net_table->oplists[i].get();
|
||||
if (op->inputIndexes() == nullptr || op->inputIndexes()->size() != 1) {
|
||||
continue;
|
||||
}
|
||||
switch (op->type()) {
|
||||
case MNN::OpType_DepthwiseConvInt8:
|
||||
case MNN::OpType_ConvInt8:
|
||||
case MNN::OpType_ConvolutionDepthwise:
|
||||
case MNN::OpType_Convolution: {
|
||||
std::shared_ptr<Execution> exe(backend->onCreate({}, {}, op));
|
||||
std::shared_ptr<Execution> exe;
|
||||
if (isQuantModel) {
|
||||
int inputIdx = op->inputIndexes()->Get(0);
|
||||
auto inputTensor = Tensor::create({1}, halide_type_of<float>());
|
||||
inputTensors.emplace_back(inputTensor);
|
||||
auto& inputQuantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
|
||||
if (quantInfos[inputIdx]) {
|
||||
inputQuantAttr.reset(new QuantAttr);
|
||||
inputQuantAttr->scale = quantInfos[inputIdx]->scale;
|
||||
inputQuantAttr->min = quantInfos[inputIdx]->min;
|
||||
inputQuantAttr->max = quantInfos[inputIdx]->max;
|
||||
inputQuantAttr->zero = quantInfos[inputIdx]->zero;
|
||||
} else {
|
||||
inputQuantAttr.reset();
|
||||
}
|
||||
int outputIdx = op->inputIndexes()->Get(0);
|
||||
auto outputTensor = Tensor::create({1}, halide_type_of<float>());
|
||||
inputTensors.emplace_back(outputTensor);
|
||||
auto& outputQuantAttr = TensorUtils::getDescribe(outputTensor)->quantAttr;
|
||||
if (quantInfos[outputIdx]) {
|
||||
outputQuantAttr.reset(new QuantAttr);
|
||||
outputQuantAttr->scale = quantInfos[outputIdx]->scale;
|
||||
outputQuantAttr->min = quantInfos[outputIdx]->min;
|
||||
outputQuantAttr->max = quantInfos[outputIdx]->max;
|
||||
outputQuantAttr->zero = quantInfos[outputIdx]->zero;
|
||||
} else {
|
||||
outputQuantAttr.reset();
|
||||
}
|
||||
if (inputQuantAttr && outputQuantAttr && op->main_as_Convolution2D()->quanParameter()) {
|
||||
exe.reset(backend->onCreate({inputTensor}, {outputTensor}, op));
|
||||
}
|
||||
} else {
|
||||
exe.reset(backend->onCreate({}, {}, op));
|
||||
}
|
||||
if (nullptr == exe) {
|
||||
break;
|
||||
}
|
||||
|
@ -70,9 +114,6 @@ static std::shared_ptr<BufferStorage> preRearrangeWeights( // NOLINT
|
|||
auto op = net->oplists()->Get(iter.first);
|
||||
cache.insert(std::make_pair(op, iter.second));
|
||||
}
|
||||
for (int i = 0; i < net->oplists()->size(); ++i) {
|
||||
auto op = net->oplists()->Get(i);
|
||||
}
|
||||
return net_storage;
|
||||
}
|
||||
|
||||
|
@ -129,18 +170,47 @@ StaticModule::StaticModule(const void* buffer, size_t length, const std::vector<
|
|||
if (mResource->mOutputFromTensor.empty()) {
|
||||
return;
|
||||
}
|
||||
auto rt = Express::ExecutorScope::Current()->getRuntime();
|
||||
|
||||
RuntimeInfo rt;
|
||||
if (moduleconfig.backend == nullptr) {
|
||||
rt = Express::ExecutorScope::Current()->getRuntime();
|
||||
} else {
|
||||
ScheduleConfig sche_config;
|
||||
sche_config.type = moduleconfig.backend->type;
|
||||
sche_config.backendConfig = moduleconfig.backend->config;
|
||||
rt = Interpreter::createRuntime(std::vector<ScheduleConfig>({sche_config}));
|
||||
}
|
||||
// TODO: Add Config
|
||||
ScheduleConfig config;
|
||||
config.numThread = 1;
|
||||
config.type = rt.first.begin()->first;
|
||||
config.saveTensors = outputs;
|
||||
auto scheduleInfo = Schedule::schedule(GetNet(buffer), {config});
|
||||
mResource->mConfig.numThread = 1;
|
||||
mResource->mConfig.type = rt.first.begin()->first;
|
||||
mResource->mConfig.path.mode = ScheduleConfig::Path::Mode::Tensor;
|
||||
mResource->mConfig.path.outputs = outputs;
|
||||
mResource->mConfig.saveTensors = outputs;
|
||||
mResource->mConfig.path.inputs = inputs;
|
||||
auto scheduleInfo = Schedule::schedule(GetNet(buffer), {mResource->mConfig});
|
||||
#ifdef MNN_EXPR_ENABLE_PROFILER
|
||||
Interpreter::SessionMode callBackMode = Interpreter::Session_Debug;
|
||||
#else
|
||||
Interpreter::SessionMode callBackMode = Interpreter::Session_Release;
|
||||
#endif
|
||||
auto isUsedContent = [&scheduleInfo](const Tensor* t) {
|
||||
const auto& infos = scheduleInfo.pipelineInfo[0].second;
|
||||
for (auto info : infos) {
|
||||
auto needInputs = SizeComputer::needInputContent(info.op, info.inputs.size());
|
||||
for (auto inputIdx : needInputs) {
|
||||
if (inputIdx < info.inputs.size() && info.inputs[inputIdx] == t) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
std::set<Tensor*> useContentInputs;
|
||||
for (const auto& iter : scheduleInfo.inputTensors) {
|
||||
if (isUsedContent(iter.second)) {
|
||||
useContentInputs.insert(iter.second);
|
||||
}
|
||||
}
|
||||
Interpreter::SessionMode inputMode =
|
||||
mResource->mShapeFix ? Interpreter::Session_Input_Inside : Interpreter::Session_Input_User;
|
||||
mSession.reset(new Session(std::move(scheduleInfo), callBackMode, inputMode, std::move(rt)));
|
||||
|
@ -151,6 +221,9 @@ StaticModule::StaticModule(const void* buffer, size_t length, const std::vector<
|
|||
mInputTensors.resize(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
mInputTensors[i] = mSession->getInput(inputs[i].c_str());
|
||||
if (useContentInputs.find(mInputTensors[i]) != useContentInputs.end()) {
|
||||
mResource->mUseContentInputs.insert(i);
|
||||
}
|
||||
}
|
||||
mOutputTensors.resize(mResource->mOutputFromTensor.size());
|
||||
for (int i = 0; i < mResource->mOutputFromTensor.size(); ++i) {
|
||||
|
@ -177,22 +250,18 @@ std::vector<Express::VARP> StaticModule::onForward(const std::vector<Express::VA
|
|||
if (nullptr == mInputTensors[i]) {
|
||||
continue;
|
||||
}
|
||||
auto info = inputs[i]->getInfo();
|
||||
mInputTensors[i]->buffer().type = info->type;
|
||||
auto des = TensorUtils::getDescribe(mInputTensors[i]);
|
||||
if (info->order == Express::NCHW) {
|
||||
des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
|
||||
auto exprInfo = inputs[i]->expr();
|
||||
auto inside = exprInfo.first->inside();
|
||||
auto inputTensor = inside->mOutputTensors[exprInfo.second];
|
||||
if (nullptr != inside->mCache) {
|
||||
inputTensor = Executor::getOutput(inside->mCache.get(), inside->mCacheOffset);
|
||||
}
|
||||
if (info->order == Express::NHWC) {
|
||||
des->dimensionFormat = MNN_DATA_FORMAT_NHWC;
|
||||
}
|
||||
if (info->order == Express::NC4HW4) {
|
||||
des->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
|
||||
}
|
||||
if (info->tensorArrayAttr != nullptr) {
|
||||
des->tensorArrayAttr = info->tensorArrayAttr;
|
||||
}
|
||||
resizeTensor(mInputTensors[i], info->dim);
|
||||
auto srcDes = TensorUtils::getDescribe(inputTensor);
|
||||
auto des = TensorUtils::getDescribe(mInputTensors[i]);
|
||||
des->dimensionFormat = srcDes->dimensionFormat;
|
||||
des->tensorArrayAttr = srcDes->tensorArrayAttr;
|
||||
mInputTensors[i]->buffer().type = inputTensor->buffer().type;
|
||||
resizeTensor(mInputTensors[i], inputTensor->shape());
|
||||
}
|
||||
if (!mResource->mShapeFix) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
|
@ -202,13 +271,14 @@ std::vector<Express::VARP> StaticModule::onForward(const std::vector<Express::VA
|
|||
auto srcPtr = (uint8_t*)inputs[i]->readMap<void>();
|
||||
if (srcPtr != mInputTensors[i]->buffer().host) {
|
||||
mInputTensors[i]->buffer().host = srcPtr;
|
||||
mSession->setNeedResize();
|
||||
mSession->setNeedMalloc();
|
||||
if (mResource->mUseContentInputs.find(i) != mResource->mUseContentInputs.end()) {
|
||||
mSession->setNeedResize();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mSession->getNeedResize()) {
|
||||
mSession->resize();
|
||||
}
|
||||
mSession->resize();
|
||||
if (mResource->mShapeFix) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (nullptr == mInputTensors[i]) {
|
||||
|
@ -247,34 +317,22 @@ std::vector<Express::VARP> StaticModule::onForward(const std::vector<Express::VA
|
|||
#endif
|
||||
for (int i = 0; i < mOutputTensors.size(); ++i) {
|
||||
auto currentTensor = mOutputTensors[i];
|
||||
auto& quantAttr = TensorUtils::getDescribe(currentTensor)->quantAttr;
|
||||
bool isQuant = (quantAttr && TensorUtils::DataTypeToHalideType(quantAttr->type) == currentTensor->getType());
|
||||
// copy the data when the output is reused as an input tensor that carries data;
|
||||
if (currentTensor->elementSize() > 0 && (mResource->mReusedTensors.find(mResource->mOutputFromTensor[i]) != mResource->mReusedTensors.end() || mResource->mCopyOutput)) {
|
||||
std::shared_ptr<Tensor> tmpTensor(new Tensor(currentTensor, currentTensor->getDimensionType(), false));
|
||||
if (currentTensor->elementSize() > 0 && (mResource->mReusedTensors.find(mResource->mOutputFromTensor[i]) != mResource->mReusedTensors.end() || mResource->mCopyOutput || isQuant)) {
|
||||
auto tmpTensor = new Tensor(currentTensor, currentTensor->getDimensionType(), false);
|
||||
tmpTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(tmpTensor->size(), MNN_MEMORY_ALIGN_DEFAULT);
|
||||
auto des = TensorUtils::getDescribe(mOutputTensors[i]);
|
||||
if (nullptr != des->backend) {
|
||||
currentTensor->copyToHostTensor(tmpTensor.get());
|
||||
currentTensor->copyToHostTensor(tmpTensor);
|
||||
} else {
|
||||
MNNCPUCopyBuffer(currentTensor, tmpTensor.get());
|
||||
}
|
||||
Express::Variable::Info info;
|
||||
info.dim = tmpTensor->shape();
|
||||
info.type = tmpTensor->getType();
|
||||
auto format = des->dimensionFormat;
|
||||
info.order = Express::NHWC;
|
||||
if (format == MNN_DATA_FORMAT_NCHW) {
|
||||
info.order = Express::NCHW;
|
||||
} else if (format == MNN_DATA_FORMAT_NC4HW4) {
|
||||
info.order = Express::NC4HW4;
|
||||
}
|
||||
// if this output tensor is TensorArray, copy attr
|
||||
if (des->tensorArrayAttr != nullptr) {
|
||||
info.tensorArrayAttr = des->tensorArrayAttr;
|
||||
MNNCPUCopyBuffer(currentTensor, tmpTensor);
|
||||
}
|
||||
TensorUtils::getDescribe(tmpTensor)->dimensionFormat = des->dimensionFormat;
|
||||
TensorUtils::getDescribe(tmpTensor)->tensorArrayAttr = des->tensorArrayAttr;
|
||||
outputs[mResource->mOutputFromTensor[i]] =
|
||||
Express::Variable::create(Express::Expr::create(std::move(info), tmpTensor->host<void>(),
|
||||
Express::VARP::CONSTANT, Expr::MemoryType::MOVE),
|
||||
0);
|
||||
Express::Variable::create(Express::Expr::create(tmpTensor, true), 0);
|
||||
} else {
|
||||
outputs[mResource->mOutputFromTensor[i]] = Express::Variable::create(Express::Expr::create(mOutputTensors[i]));
|
||||
}
|
||||
|
@ -293,11 +351,7 @@ Module* StaticModule::clone(CloneContext* ctx) const {
|
|||
return this->cloneBaseTo(ctx, module);
|
||||
}
|
||||
auto rt = Express::ExecutorScope::Current()->getRuntime();
|
||||
ScheduleConfig config;
|
||||
config.numThread = 1;
|
||||
config.type = rt.first.begin()->first;
|
||||
config.saveTensors = mResource->mOutputs;
|
||||
auto scheduleInfo = Schedule::schedule(GetNet(mResource->mNetStorage->buffer()), {config});
|
||||
auto scheduleInfo = Schedule::schedule(GetNet(mResource->mNetStorage->buffer()), {mResource->mConfig});
|
||||
#ifdef MNN_EXPR_ENABLE_PROFILER
|
||||
Interpreter::SessionMode callBackMode = Interpreter::Session_Debug;
|
||||
#else
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
|
||||
#include <set>
|
||||
#include <MNN/expr/Module.hpp>
|
||||
#include "core/Schedule.hpp"
|
||||
|
||||
namespace MNN {
|
||||
class Session;
|
||||
class Backend;
|
||||
|
@ -40,8 +42,10 @@ private:
|
|||
std::vector<std::pair<int, int>> mOutputFromInput;
|
||||
// the outputs will be used as inputs
|
||||
std::set<int> mReusedTensors;
|
||||
std::set<int> mUseContentInputs;
|
||||
std::shared_ptr<BufferStorage> mNetStorage;
|
||||
bool mCopyOutput = false;
|
||||
ScheduleConfig mConfig;
|
||||
};
|
||||
std::shared_ptr<Session> mSession;
|
||||
std::vector<Tensor*> mInputTensors;
|
||||
|
|
|
@ -133,11 +133,20 @@ public:
|
|||
}
|
||||
static Tensor* createImageTensor(halide_type_t type, int w, int h, int bpp, void* p = nullptr);
|
||||
|
||||
/**
|
||||
* @brief set padding value when wrap=ZERO.
|
||||
* @param value padding value.
|
||||
* @return void.
|
||||
*/
|
||||
void setPadding(uint8_t value) {
|
||||
mPaddingValue = value;
|
||||
}
|
||||
private:
|
||||
ImageProcess(const Config& config);
|
||||
Matrix mTransform;
|
||||
Matrix mTransformInvert;
|
||||
Inside* mInside;
|
||||
uint8_t mPaddingValue = 0;
|
||||
};
|
||||
} // namespace CV
|
||||
} // namespace MNN
|
||||
|
|
|
@ -47,7 +47,7 @@ struct ScheduleConfig {
|
|||
Op = 0,
|
||||
|
||||
/**
|
||||
* Tensor Mode (NOT supported yet)
|
||||
* Tensor Mode
|
||||
* - inputs means the input tensors; they can NOT be empty.
|
||||
* - outputs means the output tensors; they can NOT be empty.
|
||||
* It will find the pipeline that computes outputs from inputs.
|
||||
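A hedged usage sketch of Tensor mode, mirroring the fields StaticModule fills in earlier in this commit; the include path and the MNN_FORWARD_CPU choice are assumptions:

    #include <string>
    #include <vector>
    #include <MNN/Interpreter.hpp>

    // Build a Tensor-mode ScheduleConfig: the pipeline is searched from path.inputs to path.outputs.
    MNN::ScheduleConfig makeTensorModeConfig(const std::vector<std::string>& inputs,
                                             const std::vector<std::string>& outputs) {
        MNN::ScheduleConfig config;
        config.numThread   = 1;
        config.type        = MNN_FORWARD_CPU;
        config.saveTensors = outputs;                                   // keep requested outputs alive
        config.path.mode    = MNN::ScheduleConfig::Path::Mode::Tensor;
        config.path.inputs  = inputs;
        config.path.outputs = outputs;
        return config;
    }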
|
|
|
@ -22,7 +22,6 @@ struct OpT;
|
|||
struct Op;
|
||||
struct NetT;
|
||||
class Tensor;
|
||||
struct TensorArrayAttr;
|
||||
namespace Express {
|
||||
class Variable;
|
||||
class Expr;
|
||||
|
@ -110,7 +109,6 @@ public:
|
|||
halide_type_t type;
|
||||
int size;
|
||||
void syncSize();
|
||||
std::shared_ptr<TensorArrayAttr> tensorArrayAttr;
|
||||
};
|
||||
const std::string& name() const;
|
||||
void setName(const std::string& name);
|
||||
|
@ -181,7 +179,7 @@ public:
|
|||
MOVE,
|
||||
REF
|
||||
};
|
||||
static EXPRP create(Tensor* tensor);
|
||||
static EXPRP create(Tensor* tensor, bool own = false);
|
||||
|
||||
static EXPRP create(Variable::Info&& info, const void* ptr, VARP::InputType type, MemoryType copy = COPY);
|
||||
static EXPRP create(const OpT* op, std::vector<VARP> inputs, int outputSize = 1);
|
||||
|
@ -240,7 +238,7 @@ private:
|
|||
static void _addLinkForInputs(EXPRP expr);
|
||||
|
||||
Expr(int outputSize);
|
||||
Expr(Tensor* tensor);
|
||||
Expr(Tensor* tensor, bool own = false);
|
||||
|
||||
friend class Variable;
|
||||
friend class VARP;
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include <MNN/expr/Expr.hpp>
|
||||
#include <MNN/MNNForwardType.h>
|
||||
|
||||
namespace MNN {
|
||||
namespace Express {
|
||||
|
@ -47,6 +48,11 @@ public:
|
|||
void setParameter(Express::VARP parameter, int index);
|
||||
static Module* createEmpty(const std::vector<Express::VARP>& parameters);
|
||||
|
||||
struct BackendInfo {
|
||||
MNNForwardType type = MNN_FORWARD_CPU;
|
||||
BackendConfig* config = nullptr;
|
||||
};
|
||||
|
||||
struct Config {
|
||||
// Load module as dynamic, default static
|
||||
bool dynamic = false;
|
||||
|
@ -57,6 +63,8 @@ public:
|
|||
// The weights will be rearranged in a general way, so the best implementation
|
||||
// may not be adopted if `rearrange` is enabled.
|
||||
bool rearrange = false;
|
||||
|
||||
BackendInfo* backend = nullptr;
|
||||
};
|
||||
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr);
|
||||
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr);
|
||||
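A hedged sketch of using the per-module backend selection added above; the include paths follow this header, while the model file name and the rearrange choice are placeholders:

    #include <memory>
    #include <string>
    #include <vector>
    #include <MNN/expr/Module.hpp>

    // Load a module pinned to the CPU backend via the new Config::backend field.
    std::shared_ptr<MNN::Express::Module> loadOnCpu(const std::vector<std::string>& inputs,
                                                    const std::vector<std::string>& outputs) {
        MNN::Express::Module::BackendInfo backendInfo;
        backendInfo.type = MNN_FORWARD_CPU;           // backend for this module only

        MNN::Express::Module::Config config;
        config.rearrange = true;                      // optionally pre-rearrange the weights
        config.backend   = &backendInfo;              // read during load; keep it alive until then

        return std::shared_ptr<MNN::Express::Module>(
            MNN::Express::Module::load(inputs, outputs, "model.mnn", &config));
    }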
|
|
|
@ -73,7 +73,6 @@ public:
|
|||
static Module* ConvInt8(const ConvParameters& parameters, int bits,
|
||||
FeatureScaleStatMethod featureMethod = PerChannel,
|
||||
ScaleUpdateMethod method = MovingAverage);
|
||||
static Module* ConvOctave(const ConvParameters& parameters, float inFactor, float outFactor);
|
||||
static Module* Conv(const ConvParameters& parameters);
|
||||
static Module* ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules,
|
||||
NN::FeatureScaleStatMethod featureScaleStatMethod = PerTensor,
|
||||
|
|
|
@ -136,12 +136,16 @@ MNN_PUBLIC VARP _Conv(std::vector<int8_t>&& weight, std::vector<int>&& bias, std
|
|||
int8_t inputZeroPoint, int8_t outputZeroPoint,
|
||||
int8_t minValue, int8_t maxValue, bool accumulateToInt16);
|
||||
MNN_PUBLIC VARP _CosineSimilarity(VARP input0, VARP input1, VARP inputDim);
|
||||
|
||||
enum GridSamplePaddingMode {GRID_SAMPLE_PADDING_ZEROS, GRID_SAMPLE_PADDING_BORDER, GRID_SAMPLE_PADDING_REFLECTION};
|
||||
MNN_PUBLIC VARP _GridSample(VARP input, VARP grid, InterpolationMethod mode=BILINEAR, GridSamplePaddingMode paddingMode=GRID_SAMPLE_PADDING_ZEROS, bool alignCorners=false);
|
||||
MNN_PUBLIC VARP _FloatToInt8(VARP x, VARP scale, char minValue, char maxValue);
|
||||
MNN_PUBLIC VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t zeroPoint);
|
||||
MNN_PUBLIC VARP _Int8ToFloat(VARP x, VARP scale);
|
||||
MNN_PUBLIC VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint);
|
||||
|
||||
MNN_PUBLIC VARP _Select(VARP select, VARP input0, VARP input1);
|
||||
MNN_PUBLIC std::vector<VARP> _TopKV2(VARP input0, VARP input1);
|
||||
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
||||
|
|
|
@ -29,7 +29,7 @@ rm -rf build && mkdir build
|
|||
pushd build
|
||||
|
||||
[ -f CMakeCache.txt ] && rm CMakeCache.txt
|
||||
cmake $CMAKE_ARGS .. && make -j8
|
||||
cmake $CMAKE_ARGS .. && make -j24
|
||||
cp *.out $TOOLS_PATH
|
||||
|
||||
popd
|
||||
|
|
|
@ -31,6 +31,7 @@ cmake $CMAKE_ARGS .. && make MNN MNNTrain MNNConvert -j24
|
|||
popd
|
||||
|
||||
pushd pymnn/pip_package
|
||||
echo -e "__version__ = '$mnn_version'" > MNN/version.py
|
||||
rm -rf build && mkdir build
|
||||
rm -rf dist && mkdir dist
|
||||
rm -rf wheelhouse && mkdir wheelhouse
|
||||
|
@ -46,5 +47,5 @@ for whl in dist/*.whl; do
|
|||
auditwheel repair "$whl" --plat manylinux2014_x86_64 -w wheelhouse
|
||||
done
|
||||
cp wheelhouse/* $PACKAGE_PATH
|
||||
|
||||
rm MNN/version.py
|
||||
popd
|
||||
|
|
|
@ -34,6 +34,7 @@ cmake $CMAKE_ARGS .. && make MNN MNNTrain MNNConvert -j8
|
|||
popd
|
||||
|
||||
pushd pymnn/pip_package
|
||||
echo -e "__version__ = '$mnn_version'" > MNN/version.py
|
||||
rm -rf build && mkdir build
|
||||
rm -rf dist && mkdir dist
|
||||
for env in $python_versions; do
|
||||
|
@ -41,5 +42,5 @@ for env in $python_versions; do
|
|||
python build_wheel.py --version $mnn_version
|
||||
done
|
||||
cp dist/* $PACKAGE_PATH
|
||||
|
||||
rm MNN/version.py
|
||||
popd
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
# |--- Static
|
||||
|
||||
Param(
|
||||
[Parameter(Mandatory=$true)][String]$version,
|
||||
[Parameter(Mandatory=$true)][String]$pyc_env,
|
||||
[Parameter(Mandatory=$true)][String]$mnn_path,
|
||||
[Parameter(Mandatory=$true)][String]$path,
|
||||
|
@ -62,6 +63,7 @@ popd
|
|||
pyenv global $pyc_env
|
||||
python -c "import compileall; compileall.compile_dir('./pymnn_pyc_tmp', force=True)"
|
||||
Remove-Item .\pymnn_pyc_tmp -Include *.py -Recurse
|
||||
Set-Content -Path pymnn_pyc_tmp\version.py -Value "__version__ = '$version'"
|
||||
cp -r .\pymnn_pyc_tmp\* $PACKAGE_PATH\wrapper -Force
|
||||
rm -r -force pymnn_pyc_tmp
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ if ($opencl) {
|
|||
$CMAKE_ARGS = "$CMAKE_ARGS -DMNN_OPENCL=ON"
|
||||
}
|
||||
|
||||
Remove-Item build -Recurse -ErrorAction Ignore
|
||||
#Remove-Item build -Recurse -ErrorAction Ignore
|
||||
mkdir build
|
||||
pushd build
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ ninja MNN MNNTrain MNNConvert
|
|||
popd
|
||||
|
||||
pushd pymnn/pip_package
|
||||
Set-Content -Path MNN/version.py -Value "__version__ = '$version'"
|
||||
Remove-Item dist -Recurse -ErrorAction Ignore
|
||||
Remove-Item build -Recurse -ErrorAction Ignore
|
||||
mkdir dist
|
||||
|
@ -41,4 +42,5 @@ Foreach ($env in $python_versions) {
|
|||
Invoke-Expression "python build_wheel.py $ARGS"
|
||||
}
|
||||
cp dist/* $PACKAGE_PATH
|
||||
Remove-Item MNN/version.py -ErrorAction Ignore
|
||||
popd
|
|
@ -8,6 +8,9 @@ cmake ../../../ \
|
|||
-DANDROID_NATIVE_API_LEVEL=android-14 \
|
||||
-DANDROID_TOOLCHAIN=clang \
|
||||
-DMNN_USE_LOGCAT=false \
|
||||
-DMNN_USE_SSE=OFF \
|
||||
-DMNN_SUPPORT_BF16=OFF \
|
||||
-DMNN_BUILD_TEST=ON \
|
||||
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
|
||||
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3
|
||||
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Release builds work up to ndk-r21e (clang 9.0.9svn); Debug builds work up to ndk-r22 (clang 11.0.5)
|
||||
# https://github.com/android/ndk/wiki/Changelog-r22#changes Issues 1303
|
||||
# https://github.com/android/ndk/wiki/Changelog-r21#r21e Issues 1248
|
||||
# export ANDROID_NDK=/path/to/ndk-r21e
|
||||
|
||||
cmake ../../../ \
|
||||
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DANDROID_ABI="armeabi-v7a" \
|
||||
-DANDROID_STL=c++_static \
|
||||
-DANDROID_NATIVE_API_LEVEL=android-18 \
|
||||
-DANDROID_TOOLCHAIN=clang \
|
||||
-DMNN_USE_LOGCAT=false \
|
||||
-DMNN_USE_SSE=OFF \
|
||||
-DMNN_SUPPORT_BF16=OFF \
|
||||
-DMNN_BUILD_TEST=ON \
|
||||
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
|
||||
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3 \
|
||||
-DMNN_ARM82=ON \
|
||||
-DMNN_BUILD_BENCHMARK=ON
|
||||
|
||||
make -j8
|
|
@ -6,6 +6,9 @@ cmake ../../../ \
|
|||
-DANDROID_STL=c++_static \
|
||||
-DMNN_USE_LOGCAT=false \
|
||||
-DMNN_BUILD_BENCHMARK=ON \
|
||||
-DMNN_USE_SSE=OFF \
|
||||
-DMNN_SUPPORT_BF16=OFF \
|
||||
-DMNN_BUILD_TEST=ON \
|
||||
-DANDROID_NATIVE_API_LEVEL=android-21 \
|
||||
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
|
||||
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
#!/bin/bash
|
||||
make -j16
|
||||
adb push ./libMNN.so /data/local/tmp/MNN/libMNN.so
|
||||
adb push ./libMNN_CL.so /data/local/tmp/MNN/libMNN_CL.so
|
||||
adb push ./libMNN_Vulkan.so /data/local/tmp/MNN/libMNN_Vulkan.so
|
||||
adb push ./libMNN_GL.so /data/local/tmp/MNN/libMNN_GL.so
|
||||
adb push ./libMNN_Express.so /data/local/tmp/MNN/libMNN_Express.so
|
||||
adb push ./MNNV2Basic.out /data/local/tmp/MNN/MNNV2Basic.out
|
||||
adb shell "cd /data/local/tmp/MNN && rm -r output"
|
||||
adb shell "cd /data/local/tmp/MNN && mkdir output"
|
||||
adb push ./unitTest.out /data/local/tmp/MNN/unitTest.out
|
||||
adb push ./testModel.out /data/local/tmp/MNN/testModel.out
|
||||
adb push ./testModelWithDescrisbe.out /data/local/tmp/MNN/testModelWithDescrisbe.out
|
||||
adb push ./backendTest.out /data/local/tmp/MNN/backendTest.out
|
||||
adb push ./timeProfile.out /data/local/tmp/MNN/timeProfile.out
|
||||
DIR=MNN
|
||||
|
||||
adb push ./train.out /data/local/tmp/MNN/train.out
|
||||
adb push ./benchmark.out /data/local/tmp/MNN/benchmark.out
|
||||
adb push ./benchmarkExprModels.out /data/local/tmp/MNN/benchmarkExprModels.out
|
||||
adb push ./run_test.out /data/local/tmp/MNN/run_test.out
|
||||
make -j16
|
||||
adb push ./libMNN.so /data/local/tmp/$DIR/libMNN.so
|
||||
adb push ./libMNN_CL.so /data/local/tmp/$DIR/libMNN_CL.so
|
||||
adb push ./libMNN_Vulkan.so /data/local/tmp/$DIR/libMNN_Vulkan.so
|
||||
adb push ./libMNN_GL.so /data/local/tmp/$DIR/libMNN_GL.so
|
||||
adb push ./libMNN_Express.so /data/local/tmp/$DIR/libMNN_Express.so
|
||||
adb push ./MNNV2Basic.out /data/local/tmp/$DIR/MNNV2Basic.out
|
||||
adb shell "cd /data/local/tmp/$DIR && rm -r output"
|
||||
adb shell "cd /data/local/tmp/$DIR && mkdir output"
|
||||
adb push ./unitTest.out /data/local/tmp/$DIR/unitTest.out
|
||||
adb push ./testModel.out /data/local/tmp/$DIR/testModel.out
|
||||
adb push ./testModelWithDescrisbe.out /data/local/tmp/$DIR/testModelWithDescrisbe.out
|
||||
adb push ./backendTest.out /data/local/tmp/$DIR/backendTest.out
|
||||
adb push ./timeProfile.out /data/local/tmp/$DIR/timeProfile.out
|
||||
|
||||
adb push ./train.out /data/local/tmp/$DIR/train.out
|
||||
adb push ./benchmark.out /data/local/tmp/$DIR/benchmark.out
|
||||
adb push ./benchmarkExprModels.out /data/local/tmp/$DIR/benchmarkExprModels.out
|
||||
adb push ./run_test.out /data/local/tmp/$DIR/run_test.out
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,3 +1,5 @@
|
|||
# This CMakeLists.txt is used for PC (Windows, Mac, Linux) and Android
|
||||
|
||||
cmake_minimum_required(VERSION 3.4.1)
|
||||
project(mnnpybridge)
|
||||
|
||||
|
@ -9,6 +11,7 @@ option(MNN_OPENGL "Enable OpenGL" OFF)
|
|||
option(MNN_VULKAN "Enable Vulkan" OFF)
|
||||
option(MNN_CUDA "Enable CUDA" OFF)
|
||||
option(MNN_TENSORRT "Enable TensorRT" OFF)
|
||||
option(MNN_HIAI "Enable Huawei NPU" OFF)
|
||||
option(PYMNN_USE_ALINNPYTHON "based on AliNNPython" ON)
|
||||
option(PYMNN_RUNTIME_CHECK_VM "AliNNPython version (new/old) can be check on runtime" ON)
|
||||
option(PYMNN_NEW_PYTHON "AliNNPython new version (when PYMNN_RUNTIME_CHECK_VM=OFF)" ON)
|
||||
|
@ -35,6 +38,15 @@ endif()
|
|||
if(MNN_VULKAN)
|
||||
target_compile_definitions(mnnpybridge PRIVATE MNN_VULKAN)
|
||||
endif()
|
||||
if(MNN_CUDA)
|
||||
target_compile_definitions(mnnpybridge PRIVATE MNN_CUDA)
|
||||
endif()
|
||||
if(MNN_TENSORRT)
|
||||
target_compile_definitions(mnnpybridge PRIVATE MNN_TENSORRT)
|
||||
endif()
|
||||
if(MNN_HIAI)
|
||||
target_compile_definitions(mnnpybridge PRIVATE MNN_HIAI)
|
||||
endif()
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_compile_definitions(mnnpybridge PRIVATE PYMNN_USE_ALINNPYTHON)
|
||||
endif()
|
||||
|
@ -81,53 +93,66 @@ else()
|
|||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-stack-protector -std=c++11 -O2 -fvisibility=hidden -fvisibility-inlines-hidden")
|
||||
endif()
|
||||
|
||||
set(DEPEND_PATH "${CMAKE_CURRENT_LIST_DIR}/3rd_party")
|
||||
set(LIB_SUBPATH "")
|
||||
if(WIN32)
|
||||
if(NOT MNN_BUILD_SHARED_LIBS)
|
||||
set(LIB_SUBPATH "Static")
|
||||
elseif(MNN_WIN_RUNTIME_MT)
|
||||
set(LIB_SUBPATH "MT")
|
||||
else()
|
||||
set(LIB_SUBPATH "MD")
|
||||
endif()
|
||||
elseif(APPLE)
|
||||
if(MNN_BUILD_SHARED_LIBS)
|
||||
set(LIB_SUBPATH "Dynamic")
|
||||
else()
|
||||
set(LIB_SUBPATH "Static")
|
||||
endif()
|
||||
endif()
|
||||
if(CMAKE_BUILD_TYPE MATCHES Debug)
|
||||
set(LIB_SUBPATH "Debug/${LIB_SUBPATH}")
|
||||
else()
|
||||
set(LIB_SUBPATH "Release/${LIB_SUBPATH}")
|
||||
endif()
|
||||
if(WIN32)
|
||||
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "4")
|
||||
set(LIB_SUBPATH "x86/${LIB_SUBPATH}")
|
||||
else()
|
||||
set(LIB_SUBPATH "x64/${LIB_SUBPATH}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
target_include_directories(mnnpybridge PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src ${DEPEND_PATH}/MNN/include)
|
||||
if(PYMNN_TRAIN_API)
|
||||
set(MNN_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
|
||||
target_include_directories(mnnpybridge PRIVATE
|
||||
target_include_directories(mnnpybridge PRIVATE
|
||||
${MNN_DIR}/tools/train/source/grad ${MNN_DIR}/tools/train/source/optimizer ${MNN_DIR}/tools/train/source/transformer
|
||||
${MNN_DIR}/tools/train/source/data ${MNN_DIR}/schema/current ${MNN_DIR}/3rd_party/flatbuffers/include)
|
||||
endif()
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/MNN/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE MNN)
|
||||
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/include)
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE python)
|
||||
if(WIN32 OR APPLE OR CMAKE_SYSTEM_NAME MATCHES "^Linux")
|
||||
set(DEPEND_PATH "${CMAKE_CURRENT_LIST_DIR}/3rd_party")
|
||||
set(LIB_SUBPATH "")
|
||||
if(WIN32)
|
||||
if(NOT MNN_BUILD_SHARED_LIBS)
|
||||
set(LIB_SUBPATH "Static")
|
||||
elseif(MNN_WIN_RUNTIME_MT)
|
||||
set(LIB_SUBPATH "MT")
|
||||
else()
|
||||
set(LIB_SUBPATH "MD")
|
||||
endif()
|
||||
elseif(APPLE)
|
||||
if(MNN_BUILD_SHARED_LIBS)
|
||||
set(LIB_SUBPATH "Dynamic")
|
||||
else()
|
||||
set(LIB_SUBPATH "Static")
|
||||
endif()
|
||||
endif()
|
||||
if(CMAKE_BUILD_TYPE MATCHES Debug)
|
||||
set(LIB_SUBPATH "Debug/${LIB_SUBPATH}")
|
||||
else()
|
||||
set(LIB_SUBPATH "Release/${LIB_SUBPATH}")
|
||||
endif()
|
||||
if(WIN32)
|
||||
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "4")
|
||||
set(LIB_SUBPATH "x86/${LIB_SUBPATH}")
|
||||
else()
|
||||
set(LIB_SUBPATH "x64/${LIB_SUBPATH}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
target_include_directories(mnnpybridge PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src ${DEPEND_PATH}/MNN/include)
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/MNN/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE MNN)
|
||||
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/include)
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/AliNNPython/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE python)
|
||||
endif()
|
||||
if(PYMNN_NUMPY_USABLE)
|
||||
target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/include)
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
endif()
|
||||
else()
|
||||
target_include_directories(mnnpybridge PRIVATE ${MNN_DIR}/pymnn/src ${MNN_DIR}/pymnn/android/src/main/c/include)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${MNN_DIR}/pymnn/android/src/main/jniLibs/${ANDROID_ABI})
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express)
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
|
||||
endif()
|
||||
if(PYMNN_NUMPY_USABLE)
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
endif()
|
||||
endif()
|
||||
if(PYMNN_NUMPY_USABLE)
|
||||
target_include_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/include)
|
||||
target_link_directories(mnnpybridge PRIVATE ${DEPEND_PATH}/numpy/lib/${LIB_SUBPATH})
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
endif()
|
|
@ -1,3 +1,7 @@
|
|||
# version.py is generated by scripts (build_whl.sh on PC, update_mnn_wrapper_assets.sh on mobile),
|
||||
# so don't worry about it; don't edit it, and don't create version.py manually
|
||||
from .version import __version__
|
||||
|
||||
_Slice = slice
|
||||
_Int = int
|
||||
_newaxis = None
|
||||
|
|
|
@ -2,32 +2,41 @@ _Int = int
|
|||
_Float = float
|
||||
from _mnncengine._expr import *
|
||||
import _mnncengine._expr as _F
|
||||
import numpy as np
|
||||
|
||||
_numpy_supported = False
|
||||
try:
|
||||
import numpy as np
|
||||
_numpy_supported = True
|
||||
except Exception:
|
||||
print ("Numpy not found. Using MNN without numpy.")
|
||||
|
||||
def _to_var(x, to_float=True):
|
||||
if isinstance(x, np.ndarray):
|
||||
if to_float:
|
||||
if x.dtype != np.float32:
|
||||
x = x.astype(np.float32)
|
||||
return _F.const(x, x.shape)
|
||||
if not to_float:
|
||||
if x.dtype != np.int32:
|
||||
x = x.astype(np.int32)
|
||||
return _F.const(x, x.shape, dtype=_F.int)
|
||||
elif isinstance(x, (list, tuple)) and x:
|
||||
x = np.array(x)
|
||||
if to_float:
|
||||
if x.dtype != np.float32:
|
||||
x = x.astype(np.float32)
|
||||
return _F.const(x, x.shape)
|
||||
if not to_float:
|
||||
if x.dtype != np.int32:
|
||||
x = x.astype(np.int32)
|
||||
return _F.const(x, x.shape, dtype=_F.int)
|
||||
elif isinstance(x, _Int):
|
||||
return _F.const(x, [], dtype=_F.int)
|
||||
elif isinstance(x, _Float):
|
||||
return _F.const(x, [], dtype=_F.float)
|
||||
return x
|
||||
if _numpy_supported:
|
||||
if isinstance(x, np.ndarray): # convert numpy ndarray to MNN var
|
||||
if to_float:
|
||||
if x.dtype != np.float32:
|
||||
x = x.astype(np.float32)
|
||||
return _F.const(x, x.shape)
|
||||
if not to_float:
|
||||
if x.dtype != np.int32:
|
||||
x = x.astype(np.int32)
|
||||
return _F.const(x, x.shape, dtype=_F.int)
|
||||
elif isinstance(x, (list, tuple)) and x: # convert list and tuple to MNN Var
|
||||
x = np.array(x)
|
||||
if to_float:
|
||||
if x.dtype != np.float32:
|
||||
x = x.astype(np.float32)
|
||||
return _F.const(x, x.shape)
|
||||
if not to_float:
|
||||
if x.dtype != np.int32:
|
||||
x = x.astype(np.int32)
|
||||
return _F.const(x, x.shape, dtype=_F.int)
|
||||
else: # No numpy support
|
||||
if isinstance(x, _Int):
|
||||
return _F.const(x, [], dtype=_F.int)
|
||||
elif isinstance(x, _Float):
|
||||
return _F.const(x, [], dtype=_F.float)
|
||||
return x
|
||||
def scalar(value):
|
||||
if type(value) == type(1):
|
||||
res = _F.const([value], [], _F.NCHW, _F.int)
|
||||
|
@ -56,17 +65,17 @@ def square(x):
|
|||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
raise RuntimeError("parameter x is not valid")
|
||||
return _F.square(x)
|
||||
return _F.square(x)
|
||||
def sqrt(x):
|
||||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
raise RuntimeError("parameter x is not valid")
|
||||
return _F.sqrt(x)
|
||||
return _F.sqrt(x)
|
||||
def rsqrt(x):
|
||||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
raise RuntimeError("parameter x is not valid")
|
||||
return _F.rsqrt(x)
|
||||
return _F.rsqrt(x)
|
||||
def exp(x):
|
||||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
|
@ -101,7 +110,7 @@ def acos(x):
|
|||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
raise RuntimeError("parameter x is not valid")
|
||||
return _F.acos(x)
|
||||
return _F.acos(x)
|
||||
def atan(x):
|
||||
x = _to_var(x)
|
||||
if not isinstance(x, Var):
|
||||
|
@ -231,7 +240,7 @@ def space_to_batch_nd(input, block_shape, paddings):
|
|||
if len(block_shape.shape) != 1:
|
||||
raise RuntimeError("parameter block_shape must be 1-D w/ shape [M]")
|
||||
if len(paddings.shape) != 2 or paddings.shape[-1] != 2:
|
||||
raise RuntimeError("parameter paddings must be 2-D w/ shape [M, 2]")
|
||||
raise RuntimeError("parameter paddings must be 2-D w/ shape [M, 2]")
|
||||
return _F.space_to_batch_nd(input, block_shape, paddings)
|
||||
def batch_to_space_nd(input, block_shape, crops):
|
||||
input = _to_var(input)
|
||||
|
@ -355,7 +364,7 @@ def stack(values, axis=0):
|
|||
if not isinstance(value, Var):
|
||||
raise RuntimeError("all items in parameter values must be MNN Var type")
|
||||
if value.shape != values[0].shape or value.dtype != values[0].dtype:
|
||||
raise RuntimeError("all items in parameter values must have same shape and dtype")
|
||||
raise RuntimeError("all items in parameter values must have same shape and dtype")
|
||||
return _F.stack(values, axis)
|
||||
def slice(input, starts, sizes):
|
||||
input = _to_var(input)
|
||||
|
@ -419,7 +428,7 @@ def crop(images, size, axis, offset):
|
|||
raise RuntimeError("parameter offset must be at most 2 if you want to change h/w")
|
||||
if axis == 3:
|
||||
if len(offset) != 1:
|
||||
raise RuntimeError("parameter offset must be at most 1 if you want to change w only")
|
||||
raise RuntimeError("parameter offset must be at most 1 if you want to change w only")
|
||||
return _F.crop(images, size, axis, offset)
|
||||
def crop_and_resize(image, boxes, box_ind, crop_size, method=BILINEAR, extrapolation_value=0.):
|
||||
image = _to_var(image)
|
||||
|
@ -468,12 +477,12 @@ def reshape(x, shape, original_format=NCHW):
|
|||
if not isinstance(shape, (list, tuple)):
|
||||
raise RuntimeError("parameter shape is not valid")
|
||||
new_length = 1
|
||||
skip = False
|
||||
skip = False
|
||||
for value in shape:
|
||||
if value < 0:
|
||||
skip = True
|
||||
new_length *= value
|
||||
|
||||
|
||||
if new_length != x.size and not skip:
|
||||
raise RuntimeError("parameter shape is not valid")
|
||||
return _F.reshape(x, shape, original_format)
|
||||
return _F.reshape(x, shape, original_format)
|
||||
|
|
|
@ -7,7 +7,15 @@ import _mnncengine._nn as _nn
|
|||
def load_module_from_file(file_name, input_names, output_names, **kwargs):
|
||||
dynamic = kwargs.get('dynamic', False)
|
||||
shape_mutable = kwargs.get('shape_mutable', False)
|
||||
module = _nn.load_module_from_file(input_names, output_names, file_name, dynamic, shape_mutable)
|
||||
rearrange = kwargs.get('rearrange', False)
|
||||
backend = kwargs.get('backend', _F.Backend.CPU)
|
||||
memory_mode = kwargs.get('memory_mode', _F.MemoryMode.Normal)
|
||||
power_mode = kwargs.get('power_mode', _F.PowerMode.Normal)
|
||||
precision_mode = kwargs.get('precision_mode', _F.PrecisionMode.Normal)
|
||||
thread_num = kwargs.get('thread_num', 1)
|
||||
|
||||
module = _nn.load_module_from_file(input_names, output_names, file_name, dynamic, shape_mutable, rearrange,
|
||||
backend, memory_mode, power_mode, precision_mode, thread_num)
|
||||
|
||||
return module
|
||||
|
||||
|
|
|
@ -3,38 +3,23 @@
|
|||
""" python wrapper file for mnn converter tool """
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import _tools as Tools
|
||||
|
||||
def usage():
|
||||
""" print usage info """
|
||||
print("usage: mnnconvert [-h]")
|
||||
print(" [--framework {TF,CAFFE,ONNX,TFLITE,MNN}")
|
||||
print(" [--modelFile MODELFILE]")
|
||||
print(" [--prototxt PROTOTXT]")
|
||||
print(" [--MNNModel MNNMODEL]")
|
||||
print(" [--fp16 {True,False}]")
|
||||
print(" [--weightQuantBits {num of bits for weight-only-quant, default:0, which means no quant}]")
|
||||
print(" [--weightQuantAsymmetric {True,False use asymmetric quant method for weight-only-quant, \
|
||||
the default method is symmetric quant, which is compatible with old MNN versions. \
|
||||
you can set this flag to True use asymmetric quant method to improve accuracy of the weight-quant model in some cases, \
|
||||
but asymmetric quant model cannot run on old MNN versions. You will need to upgrade MNN to new version to solve this problem. \
|
||||
default: False, which means using SYMMETRIC quant method}]")
|
||||
print(" [--compressionParamsFile COMPRESSION_PARAMS_PATH]")
|
||||
|
||||
def main():
|
||||
""" main funcion """
|
||||
accepted_framework = ['TF', 'CAFFE', 'ONNX', 'TFLITE', 'MNN']
|
||||
TF, CAFFE, ONNX, MNN, TFLITE = 0, 1, 2, 3, 4
|
||||
framework_map = {'TF': TF, 'CAFFE': CAFFE, 'ONNX': ONNX, 'TFLITE': TFLITE, 'MNN': MNN}
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f", "--framework", type=str,\
|
||||
choices=['TF', 'CAFFE', 'ONNX', 'TFLITE', 'MNN'], default='TF',\
|
||||
required=True, help="model type, for example:TF/CAFFE/ONNX/TFLITE/MNN")
|
||||
choices=list(framework_map.keys()), default='TF', required=True, help="model type")
|
||||
parser.add_argument("--modelFile", type=str, required=True,\
|
||||
help="tensorflow Pb or caffeModel, for example:xxx.pb/xxx.caffemodel")
|
||||
parser.add_argument("--prototxt", type=str,\
|
||||
help="only used for caffe, for example: xxx.prototxt")
|
||||
parser.add_argument("--MNNModel", type=str, required=True,\
|
||||
help="MNN model, ex: xxx.mnn")
|
||||
parser.add_argument("--prototxt", type=str, help="only used for caffe, for example: xxx.prototxt")
|
||||
parser.add_argument("--MNNModel", type=str, required=True, help="MNN model, ex: xxx.mnn")
|
||||
parser.add_argument("--bizCode", type=str, required=True, help="bizcode, ex: MNN")
|
||||
parser.add_argument("--fp16", type=bool, default=False,\
|
||||
help="{True,False}\
|
||||
Boolean to change the mnn usage. If True, the output\
|
||||
|
@ -45,31 +30,13 @@ def main():
|
|||
help="The path of model compression file that stores the int8 calibration \
|
||||
table for quantization or auxiliary parameters for sparsity.")
|
||||
|
||||
TF = 0
|
||||
CAFFE = 1
|
||||
ONNX = 2
|
||||
MNN = 3
|
||||
TFLITE = 4
|
||||
args = parser.parse_args()
|
||||
if args.framework.upper() in accepted_framework:
|
||||
if args.framework == 'TF':
|
||||
framework_type = TF
|
||||
elif args.framework.upper() == 'CAFFE':
|
||||
framework_type = CAFFE
|
||||
elif args.framework.upper() == 'ONNX':
|
||||
framework_type = ONNX
|
||||
elif args.framework.upper() == 'MNN':
|
||||
framework_type = MNN
|
||||
elif args.framework.upper() == 'TFLITE':
|
||||
framework_type = TFLITE
|
||||
else:
|
||||
usage()
|
||||
return -1
|
||||
framework_type = framework_map[args.framework]
|
||||
if args.modelFile is None or not os.path.exists(args.modelFile):
|
||||
print("modelfile not exist")
|
||||
return -1
|
||||
if args.MNNModel is None:
|
||||
usage()
|
||||
parser.print_help(sys.stderr)
|
||||
return -1
|
||||
if args.framework.upper() == 'CAFFE':
|
||||
if args.prototxt is None or not os.path.exists(args.prototxt):
|
||||
|
@ -86,7 +53,7 @@ def main():
|
|||
args.compressionParamsFile = ""
|
||||
|
||||
Tools.mnnconvert(args.MNNModel, args.modelFile, framework_type,\
|
||||
args.fp16, args.prototxt, args.weightQuantBits, args.weightQuantAsymmetric, args.compressionParamsFile)
|
||||
args.fp16, args.prototxt, args.weightQuantBits, args.weightQuantAsymmetric, args.compressionParamsFile, args.bizCode)
|
||||
return 0
|
||||
if __name__ == "__main__":
|
||||
main()
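For context, a hedged example of driving this wrapper after --bizCode became a required argument; the console-script name and file paths are assumptions.

    import subprocess

    # Invoke the converter CLI; --bizCode is now forwarded to Tools.mnnconvert.
    subprocess.run([
        "mnnconvert",
        "--framework", "ONNX",
        "--modelFile", "model.onnx",   # placeholder source model
        "--MNNModel", "model.mnn",     # placeholder output path
        "--bizCode", "MNN",            # newly required by this change
    ], check=True)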
|
||||
|
|
|
@ -185,6 +185,7 @@ def configure_extension_build():
|
|||
tools_include_dirs += [os.path.join(root_dir, "source", "core")]
|
||||
tools_include_dirs += [os.path.join(root_dir, "schema", "current")]
|
||||
tools_include_dirs += [os.path.join(root_dir, "source")]
|
||||
tools_include_dirs += [np.get_include()]
|
||||
if IS_WINDOWS:
|
||||
tools_include_dirs += [os.path.join(os.environ['Protobuf_SRC_ROOT_FOLDER'], 'src')]
|
||||
|
||||
|
@ -206,7 +207,6 @@ def configure_extension_build():
|
|||
engine_extra_link_args += ['-Wl,--no-whole-archive']
|
||||
if IS_WINDOWS:
|
||||
engine_extra_link_args += ['/WHOLEARCHIVE:MNN.lib']
|
||||
engine_extra_link_args += ['/WHOLEARCHIVE:MNNTrain.lib']
|
||||
if IS_DARWIN:
|
||||
tools_extra_link_args += ['-Wl,-all_load']
|
||||
tools_extra_link_args += tools_depend
|
||||
|
|
138
pymnn/src/MNN.cc
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
#include "MNNPyBridge.h"
|
||||
#include "common.h"
|
||||
#include "util.h"
|
||||
|
||||
static int tls_key = 0;
|
||||
static int tls_key_2 = 0;
|
||||
|
@ -28,8 +29,10 @@ namespace py = pybind11;
|
|||
#include <MNN/expr/Expr.hpp>
|
||||
#include <MNN/expr/ExprCreator.hpp>
|
||||
#include <MNN/expr/Executor.hpp>
|
||||
//#include <MNN/expr/ExecutorScope.hpp>
|
||||
#include <MNN/expr/NN.hpp>
|
||||
#include <MNN/expr/Module.hpp>
|
||||
using namespace MNN::Express;
|
||||
#endif // PYMNN_EXPR_API
|
||||
|
||||
#ifdef BUILD_OPTYPE
|
||||
|
@ -45,15 +48,15 @@ namespace py = pybind11;
|
|||
#include "DataLoader.hpp"
|
||||
#include "Loss.hpp"
|
||||
#include "Transformer.hpp"
|
||||
#include "PipelineModule.hpp"
|
||||
using namespace MNN::Train;
|
||||
#endif // PYMNN_TRAIN_API
|
||||
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include "util.h"
|
||||
|
||||
using namespace MNN;
|
||||
using namespace MNN::Express;
|
||||
|
||||
using namespace std;
|
||||
|
||||
struct MNN_TLSData {
|
||||
|
@ -598,6 +601,8 @@ static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject
|
|||
config.type = MNN_FORWARD_CPU;
|
||||
if (backend) {
|
||||
auto backend_name = object2String(backend);
|
||||
// Avoid misusing a backend that is not supported by the bridge and the corresponding MNN library at the Python level,
|
||||
// so that users ask us for the right version of the bridge library, similar to the MNN.expr.Backend.* Python enum
|
||||
std::unordered_map<std::string, MNNForwardType> backend_map = {
|
||||
{"CPU", MNN_FORWARD_CPU},
|
||||
#ifdef MNN_OPENCL
|
||||
|
@ -617,10 +622,14 @@ static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject
|
|||
#endif
|
||||
#ifdef MNN_CUDA
|
||||
{"CUDA", MNN_FORWARD_CUDA},
|
||||
#endif
|
||||
#ifdef MNN_HIAI
|
||||
{"HIAI", MNN_FORWARD_USER_0}
|
||||
#endif
|
||||
};
|
||||
auto iter = backend_map.find(backend_name);
|
||||
if (iter == backend_map.end()) {
|
||||
// backend not supported; raise a Python-level error during development
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNInterpreter_createSession: backend not support");
|
||||
return NULL;
|
||||
|
@ -1117,8 +1126,8 @@ static int PyMNNInterpreter_init(PyMNNInterpreter *self, PyObject *args, PyObjec
|
|||
"PyMNNInterpreter_new: PyArg_ParseTuple failed");
|
||||
return -1;
|
||||
}
|
||||
|
||||
self->modelPath = new std::string(path);
|
||||
auto converted_path = convertBytesEncodeIfNeed(path);
|
||||
self->modelPath = new std::string(converted_path.data());
|
||||
if (!self->modelPath) {
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"PyMNNInterpreter_new: create modelPath string failed");
|
||||
|
@ -1517,7 +1526,7 @@ static PyObject* PyMNNTensor_getNumpyData(PyMNNTensor *self, PyObject *args) {
|
|||
auto data = self->tensor->host<double>();
|
||||
obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_DOUBLE, data);
|
||||
} else {
|
||||
MNN_PRINT("tensor can not be read as numpy\n");
|
||||
PyErr_SetString(PyExc_Exception, "tensor can not be read as numpy");
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
return obj;
|
||||
|
@ -2142,27 +2151,27 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
|
|||
#endif
|
||||
|
||||
if (PyType_Ready(&PyMNNInterpreterType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNInterpreterType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNInterpreterType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
if (PyType_Ready(&PyMNNSessionType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNSessionType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNSessionType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
if (PyType_Ready(&PyMNNTensorType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNTensorType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNTensorType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
if (PyType_Ready(&PyMNNCVImageProcessType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNCVImageProcessType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVImageProcessType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
if (PyType_Ready(&PyMNNCVMatrixType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNCVMatrixType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVMatrixType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
if (PyType_Ready(&PyMNNOpInfoType) < 0) {
|
||||
printf("initMNN: PyType_Ready PyMNNOpInfoType failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNOpInfoType failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
|
@ -2172,12 +2181,12 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
|
|||
#endif
|
||||
// module import failed!
|
||||
if (!m) {
|
||||
printf("initMNN: import MNN failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: import MNN failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
#ifdef PYMNN_NUMPY_USABLE
|
||||
if(_import_array() < 0) {
|
||||
printf("initMNN: init numpy failed");
|
||||
PyErr_SetString(PyExc_Exception, "initMNN: init numpy failed");
|
||||
ERROR_RETURN
|
||||
}
|
||||
#endif
|
||||
|
@ -2614,18 +2623,67 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
|
|||
exe->gc(Executor::PART);
|
||||
}
|
||||
});
|
||||
expr_module.def("set_thread_number",
|
||||
[](int numberThread) {
|
||||
if (numberThread < 1) {
|
||||
numberThread = 1;
|
||||
}
|
||||
if (numberThread > 8) {
|
||||
numberThread = 8;
|
||||
py::enum_<MNNForwardType>(expr_module, "Backend")
|
||||
.value("CPU", MNN_FORWARD_CPU)
|
||||
#ifdef MNN_OPENCL
|
||||
.value("OPENCL", MNN_FORWARD_OPENCL)
|
||||
#endif
|
||||
#ifdef MNN_OPENGL
|
||||
.value("OPENGL", MNN_FORWARD_OPENGL)
|
||||
#endif
|
||||
#ifdef MNN_VULKAN
|
||||
.value("VULKAN", MNN_FORWARD_VULKAN)
|
||||
#endif
|
||||
#ifdef MNN_METAL
|
||||
.value("METAL", MNN_FORWARD_METAL)
|
||||
#endif
|
||||
#ifdef MNN_TENSORRT
|
||||
.value("TRT", MNN_FORWARD_USER_1)
|
||||
#endif
|
||||
#ifdef MNN_CUDA
|
||||
.value("CUDA", MNN_FORWARD_CUDA)
|
||||
#endif
|
||||
#ifdef MNN_HIAI
|
||||
.value("HIAI", MNN_FORWARD_USER_0)
|
||||
#endif
|
||||
.export_values();
|
||||
|
||||
using MemoryMode = BackendConfig::MemoryMode;
|
||||
using PowerMode = BackendConfig::PowerMode;
|
||||
using PrecisionMode = BackendConfig::PrecisionMode;
|
||||
py::enum_<MemoryMode>(expr_module, "MemoryMode")
|
||||
.value("Normal", MemoryMode::Memory_Normal)
|
||||
.value("High", MemoryMode::Memory_High)
|
||||
.value("Low", MemoryMode::Memory_Low)
|
||||
.export_values();
|
||||
py::enum_<PowerMode>(expr_module, "PowerMode")
|
||||
.value("Normal", PowerMode::Power_Normal)
|
||||
.value("High", PowerMode::Power_High)
|
||||
.value("Low", PowerMode::Power_Low)
|
||||
.export_values();
|
||||
py::enum_<PrecisionMode>(expr_module, "PrecisionMode")
|
||||
.value("Normal", PrecisionMode::Precision_Normal)
|
||||
.value("High", PrecisionMode::Precision_High)
|
||||
.value("Low", PrecisionMode::Precision_Low)
|
||||
.export_values();
|
||||
expr_module.def("set_config",
|
||||
[](MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode, PrecisionMode precision_mode, int thread_num) {
|
||||
if (thread_num < 1 || thread_num > 8) {
|
||||
PyErr_SetString(PyExc_Exception, "thread_num should bigger than 0 and less than 9");
|
||||
}
|
||||
thread_num = std::max(std::min(thread_num, 8), 1);
|
||||
//auto exe = ExecutorScope::Current();
|
||||
auto exe = Executor::getGlobalExecutor();
|
||||
BackendConfig config;
|
||||
exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, numberThread);
|
||||
});
|
||||
config.memory = memory_mode;
|
||||
config.power = power_mode;
|
||||
config.precision = precision_mode;
|
||||
exe->setGlobalExecutorConfig(backend, config, thread_num);
|
||||
},
|
||||
py::arg("backend")=MNN_FORWARD_CPU, py::arg("memory_mode")=MemoryMode::Memory_Normal,
|
||||
py::arg("power_mode")=PowerMode::Power_Normal, py::arg("precision_mode")=PrecisionMode::Precision_Normal,
|
||||
py::arg("thread_num")=1);
|
||||
|
||||
//Begin of Math OPS
|
||||
//Unary OPS
|
||||
expr_module.def("sign", &Express::_Sign);
|
||||
|
@ -3018,12 +3076,32 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
|
|||
return Module::extract(inputs, outputs, fortrain);
|
||||
});
|
||||
nn_module.def("load_module_from_file", [](const vector<string>& inputs, const vector<string>& outputs,
|
||||
const char* file_name, bool dynamic, bool shape_mutable) -> Module* {
|
||||
//Module::Config config {dynamic, shape_mutable};
|
||||
const char* file_name, bool dynamic, bool shape_mutable, bool rearrange,
|
||||
MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode,
|
||||
PrecisionMode precision_mode, int thread_num) -> Module* {
|
||||
BackendConfig backend_config;
|
||||
backend_config.memory = memory_mode;
|
||||
backend_config.power = power_mode;
|
||||
backend_config.precision = precision_mode;
|
||||
|
||||
Module::BackendInfo backend_info;
|
||||
backend_info.type = backend;
|
||||
backend_info.config = &backend_config;
|
||||
|
||||
Module::Config config;
|
||||
config.dynamic = dynamic;
|
||||
config.shapeMutable = shape_mutable;
|
||||
return Module::load(inputs, outputs, file_name, &config);
|
||||
config.rearrange = rearrange;
|
||||
config.backend = &backend_info;
|
||||
|
||||
auto converted_file_name = convertBytesEncodeIfNeed(file_name);
|
||||
auto m_ptr = Module::load(inputs, outputs, converted_file_name.data(), &config);
|
||||
if (m_ptr == nullptr) {
|
||||
std::string mnn_errno = "load_module_from_file failed ";
|
||||
mnn_errno = mnn_errno + std::string(file_name);
|
||||
PyErr_SetString(PyExc_Exception, mnn_errno.c_str());
|
||||
}
|
||||
return m_ptr;
|
||||
});
|
||||
|
||||
// CNN
|
||||
|
@ -3188,11 +3266,11 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) {
|
|||
.value("MAXIMUM", NN::Maximum)
|
||||
.value("MOVING_AVERAGE", NN::MovingAverage)
|
||||
.export_values();
|
||||
// compress_module.def("train_quant", &PipelineModule::turnQuantize,
|
||||
// py::arg("module"),
|
||||
// py::arg("quant_bits") = 8,
|
||||
// py::arg("feature_scale_method") = NN::FeatureScaleStatMethod::PerTensor,
|
||||
// py::arg("scale_update_method") = NN::ScaleUpdateMethod::MovingAverage);
|
||||
compress_module.def("train_quant", &PipelineModule::turnQuantize,
|
||||
py::arg("module"),
|
||||
py::arg("quant_bits") = 8,
|
||||
py::arg("feature_scale_method") = NN::FeatureScaleStatMethod::PerTensor,
|
||||
py::arg("scale_update_method") = NN::ScaleUpdateMethod::MovingAverage);
|
||||
}
|
||||
// End of Train
|
||||
#endif
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
*/
|
||||
#include <Python.h>
|
||||
#include "structmember.h"
|
||||
|
||||
#include "util.h"
|
||||
#include "MNN_generated.h"
|
||||
#include "PostConverter.hpp"
|
||||
#include "addBizCode.hpp"
|
||||
|
@ -13,7 +13,6 @@
|
|||
#include "tensorflowConverter.hpp"
|
||||
#include "writeFb.hpp"
|
||||
#include "config.hpp"
|
||||
#include "options.hpp"
|
||||
#include "common/Global.hpp"
|
||||
#include "calibration.hpp"
|
||||
#include "logkit.h"
|
||||
|
@ -27,48 +26,48 @@ static PyObject* PyTool_Converter(PyObject *self, PyObject *args) {
|
|||
const char* modelFile = NULL;
|
||||
const char* compressionParamsFile = NULL;
|
||||
const char* prototxtFile = NULL;
|
||||
const char* bizCode = NULL;
|
||||
PyObject* frameworkType = NULL;
|
||||
PyObject* fp16 = NULL;
|
||||
PyObject* weightQuantBits = NULL;
|
||||
PyObject* weightQuantAsymmetric = NULL;
|
||||
if (!PyArg_ParseTuple(args, "ssOO|sOOs", &mnnModel, &modelFile,
|
||||
if (!PyArg_ParseTuple(args, "ssOO|sOOss", &mnnModel, &modelFile,
|
||||
&frameworkType, &fp16, &prototxtFile,
|
||||
&weightQuantBits, &weightQuantAsymmetric, &compressionParamsFile)) {
|
||||
&weightQuantBits, &weightQuantAsymmetric, &compressionParamsFile,
|
||||
&bizCode)) {
|
||||
return NULL;
|
||||
}
|
||||
struct modelConfig modelPath;
|
||||
modelPath.MNNModel = std::string(mnnModel);
|
||||
modelPath.modelFile = std::string(modelFile);
|
||||
modelPath.MNNModel = convertBytesEncodeIfNeed(mnnModel);
|
||||
modelPath.modelFile = convertBytesEncodeIfNeed(modelFile);
|
||||
modelPath.model = static_cast<modelConfig::MODEL_SOURCE>(PyLong_AsLong(frameworkType));
|
||||
modelPath.bizCode = std::string("");
|
||||
modelPath.bizCode = std::string(bizCode);
|
||||
modelPath.benchmarkModel = false;
|
||||
modelPath.saveHalfFloat = static_cast<bool>(PyLong_AsLong(fp16));
|
||||
modelPath.forTraining = false;
|
||||
modelPath.weightQuantBits = static_cast<int>(PyLong_AsLong(weightQuantBits));
|
||||
modelPath.weightQuantAsymmetric = static_cast<bool>(PyLong_AsLong(weightQuantAsymmetric));
|
||||
if(prototxtFile){
|
||||
modelPath.prototxtFile = std::string(prototxtFile);
|
||||
modelPath.prototxtFile = convertBytesEncodeIfNeed(prototxtFile);
|
||||
}
|
||||
|
||||
common::Options options;
|
||||
if (compressionParamsFile) {
|
||||
modelPath.compressionParamsFile = std::string(compressionParamsFile);
|
||||
options = common::BuildOptions(modelPath.compressionParamsFile);
|
||||
modelPath.compressionParamsFile = convertBytesEncodeIfNeed(compressionParamsFile);
|
||||
}
|
||||
|
||||
Global<modelConfig>::Reset(&modelPath);
|
||||
|
||||
std::unique_ptr<MNN::NetT> netT = std::unique_ptr<MNN::NetT>(new MNN::NetT());
|
||||
if (modelPath.model == modelConfig::CAFFE) {
|
||||
caffe2MNNNet(modelPath.prototxtFile, modelPath.modelFile, modelPath.bizCode, options, netT);
|
||||
caffe2MNNNet(modelPath.prototxtFile, modelPath.modelFile, modelPath.bizCode, netT);
|
||||
} else if (modelPath.model == modelConfig::TENSORFLOW) {
|
||||
tensorflow2MNNNet(modelPath.modelFile, modelPath.bizCode, options, netT);
|
||||
tensorflow2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
|
||||
} else if (modelPath.model == modelConfig::MNN) {
|
||||
addBizCode(modelPath.modelFile, modelPath.bizCode, options, netT);
|
||||
addBizCode(modelPath.modelFile, modelPath.bizCode, netT);
|
||||
} else if (modelPath.model == modelConfig::ONNX) {
|
||||
onnx2MNNNet(modelPath.modelFile, modelPath.bizCode, options, netT);
|
||||
onnx2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
|
||||
} else if (modelPath.model == modelConfig::TFLITE) {
|
||||
tflite2MNNNet(modelPath.modelFile, modelPath.bizCode, options, netT);
|
||||
tflite2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
|
||||
} else {
|
||||
std::cout << "Not Support Model Type" << std::endl;
|
||||
}
|
||||
|
|
|
@ -50,4 +50,4 @@ static int global_new_python_flag = 0;
|
|||
#include <Python.h>
|
||||
#include "structmember.h"
|
||||
#include "numpy/arrayobject.h"
|
||||
#endif // PYMNN_USE_ALINNPYTHON
|
||||
#endif // PYMNN_USE_ALINNPYTHON
|
||||
|
|
|
@ -1,10 +1,44 @@
|
|||
#pragma once
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <MNN/HalideRuntime.h>
|
||||
#if defined(_MSC_VER) && PY_MAJOR_VERSION >= 3
|
||||
#include <Windows.h>
|
||||
#include <stringapiset.h>
|
||||
#endif
|
||||
#include "common.h"
|
||||
|
||||
using namespace std;
|
||||
typedef vector<int> INTS;
|
||||
|
||||
// In Python 3, the default str is unicode, which pybind transforms to UTF-8 bytes.
|
||||
// On Windows, the MNN library assumes input bytes are encoded as CP_ACP.
|
||||
// So we need: UTF-8 bytes -> unicodes -> CP_ACP bytes
|
||||
inline std::string convertBytesEncodeIfNeed(const char* srcBytes) {
|
||||
#if defined(_MSC_VER) && PY_MAJOR_VERSION >= 3
|
||||
int wideCharSize = MultiByteToWideChar(CP_UTF8, 0, srcBytes, -1, nullptr, 0);
|
||||
if (wideCharSize == 0) {
|
||||
return {};
|
||||
}
|
||||
std::unique_ptr<wchar_t[]> unicodes(new wchar_t[wideCharSize]);
|
||||
if (MultiByteToWideChar(CP_UTF8, 0, srcBytes, -1, unicodes.get(), wideCharSize) == 0) {
|
||||
return {};
|
||||
}
|
||||
int byteSize = WideCharToMultiByte(CP_ACP, 0, unicodes.get(), wideCharSize, nullptr, 0, nullptr, nullptr);
|
||||
if (byteSize == 0) {
|
||||
return {};
|
||||
}
|
||||
std::unique_ptr<char[]> dstBytes(new char[byteSize]);
|
||||
if (WideCharToMultiByte(CP_ACP, 0, unicodes.get(), wideCharSize, dstBytes.get(), byteSize, nullptr, nullptr) == 0) {
|
||||
return {};
|
||||
}
|
||||
return {dstBytes.get(), (size_t)byteSize};
|
||||
#else
|
||||
return {srcBytes};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns true if obj is a bytes/str or unicode object
|
||||
inline bool checkString(PyObject* obj) {
|
||||
return PyBytes_Check(obj) || PyUnicode_Check(obj);
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -p python_version [-t]"
|
||||
echo -e "\t-p python versions in pyenv"
|
||||
echo "Usage: $0 -p python_version -v mnn_version [-t]"
|
||||
echo -e "\t-p python versions in pyenv [only support 2.x]"
|
||||
echo -e "\t-v MNN version to set"
|
||||
echo -e "\t-t include train API wrapper"
|
||||
exit 1
|
||||
}
|
||||
|
||||
while getopts "p:t" opt; do
|
||||
while getopts "p:v:t" opt; do
|
||||
case "$opt" in
|
||||
p ) py_version=$OPTARG ;;
|
||||
v ) mnn_version=$OPTARG ;;
|
||||
t ) train_api=true ;;
|
||||
* ) usage ;;
|
||||
esac
|
||||
|
@ -20,6 +23,7 @@ cp -r pip_package/MNN /tmp/mnn_py
|
|||
pushd /tmp/mnn_py/MNN
|
||||
|
||||
rm -rf tools
|
||||
echo -e "__version__ = '$mnn_version'" > version.py
|
||||
cat __init__.py | sed '/from . import tools/d' > __init__.py.tmp
|
||||
mv __init__.py.tmp __init__.py
|
||||
|
||||
|
@ -32,14 +36,41 @@ fi
|
|||
find . -name __pycache__ | xargs rm -rf
|
||||
pyenv global $py_version
|
||||
python -c "import compileall; compileall.compile_dir('/tmp/mnn_py/MNN', force=True)"
|
||||
find . -name *.py | xargs rm -rf
|
||||
find . -name "*.py" | xargs rm -rf
|
||||
cd ..
|
||||
zip -r MNN.zip MNN
|
||||
popd
|
||||
|
||||
rm -f android/src/main/assets/MNN.zip
|
||||
rm -rf iOS/MNNPyBridge/lib/MNN
|
||||
cp /tmp/mnn_py/MNN.zip android/src/main/assets
|
||||
cp -r /tmp/mnn_py/MNN iOS/MNNPyBridge/lib
|
||||
# update wrapper assets from $1 to $2 when pyc (WITHOUT METADATA) is not same
|
||||
should_update () {
|
||||
pushd $1
|
||||
pyc_files_1=(`find MNN -name *.pyc | sort`)
|
||||
popd
|
||||
pushd $2
|
||||
pyc_files_2=(`find MNN -name *.pyc | sort`)
|
||||
popd
|
||||
if [ ${#pyc_files_1[@]} -ne ${#pyc_files_2[@]} ]; then
|
||||
return 0
|
||||
fi
|
||||
for ((i=0;i<${#pyc_files_1[@]};i++)); do
|
||||
if [ ${pyc_files_1[i]} != ${pyc_files_2[i]} ]; then
|
||||
return 0
|
||||
fi
|
||||
pyc_file=${pyc_files_1[i]}
|
||||
sum_old=`tail -c +8 $2/$pyc_file | md5sum | awk '{print $1}'`
|
||||
sum_new=`tail -c +8 $1/$pyc_file | md5sum | awk '{print $1}'`
|
||||
if [ $sum_old != $sum_new ]; then
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
if should_update /tmp/mnn_py iOS/MNNPyBridge/lib; then
|
||||
rm -f android/src/main/assets/MNN.zip
|
||||
rm -rf iOS/MNNPyBridge/lib/MNN
|
||||
cp /tmp/mnn_py/MNN.zip android/src/main/assets
|
||||
cp -r /tmp/mnn_py/MNN iOS/MNNPyBridge/lib
|
||||
fi
|
||||
|
||||
rm -rf /tmp/mnn_py
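The should_update helper above compares compiled .pyc payloads while skipping the volatile header bytes; a rough Python equivalent of that comparison (the header offset is an assumption chosen to match `tail -c +8` above):

    import hashlib

    def pyc_digest(path, header_offset=7):
        # Skip the magic/timestamp header so identical source compiled at
        # different times still hashes the same, like `tail -c +8 | md5sum`.
        with open(path, "rb") as f:
            f.seek(header_offset)
            return hashlib.md5(f.read()).hexdigest()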
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
# Copies the files from Gitlab AliNN/MNN to the Github MNN repo,
|
||||
# and remove some internal files.
|
||||
# This script assumes:
|
||||
# 1. the current directory is the parent directory of "MNN"
|
||||
# 2. the current directory contains the "GithubMNN" directory
|
||||
|
||||
SOURCE="MNN"
|
||||
TARGET="GithubMNN"
|
||||
|
||||
# check dirs
|
||||
if [ ! -d $SOURCE ]; then
|
||||
echo "$SOURCE Not Found"
|
||||
exit -1
|
||||
fi
|
||||
if [ ! -d $TARGET ]; then
|
||||
echo "$TARGET Not Found"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# remove files except .git in $TARGET
|
||||
pushd $TARGET > /dev/null
|
||||
ls | grep -v .git | xargs rm -rf
|
||||
rm -f .gitignore
|
||||
popd > /dev/null
|
||||
|
||||
# copy files from $SOURCE to $TARGET
|
||||
pushd $SOURCE > /dev/null
|
||||
ls | grep -v .git | xargs -I {} cp -af {} ../$TARGET
|
||||
cp -f .gitignore ../$TARGET
|
||||
popd > /dev/null
|
||||
|
||||
# reverting files
|
||||
pushd $TARGET > /dev/null
|
||||
# git clean -df
|
||||
popd > /dev/null
|
|
@ -1,63 +0,0 @@
|
|||
# Copies the files from Gitlab AliNN/AliNNPrivate to the Gitlab AliNN/MNN repo,
|
||||
# and remove some internal files.
|
||||
# This script assumes:
|
||||
# 1. the current directory is the parent directory of "AliNNPrivate"
|
||||
# 2. the current directory contains the "MNN" directory
|
||||
|
||||
SOURCE="AliNNPrivate"
|
||||
TARGET="MNN"
|
||||
|
||||
# check dirs
|
||||
if [ ! -d $SOURCE ]; then
|
||||
echo "$SOURCE Not Found"
|
||||
exit -1
|
||||
fi
|
||||
if [ ! -d $TARGET ]; then
|
||||
echo "$TARGET Not Found"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# remove files except .git in $TARGET
|
||||
pushd $TARGET > /dev/null
|
||||
ls | grep -v .git | xargs rm -rf
|
||||
rm -f .gitignore
|
||||
popd > /dev/null
|
||||
|
||||
# copy files from $SOURCE to $TARGET
|
||||
pushd $SOURCE > /dev/null
|
||||
# Remove gitignored and untracked files.
|
||||
git clean -df
|
||||
|
||||
ls | grep -v .git | xargs -I {} cp -af {} ../$TARGET
|
||||
cp -f .gitignore ../$TARGET
|
||||
rm -rf ../$TARGET/release_scripts
|
||||
rm -rf ../$TARGET/pymnn/android
|
||||
rm -rf ../$TARGET/pymnn/iOS
|
||||
rm -f ../$TARGET/pymnn/renameForAliNNPython.h
|
||||
rm -f ../$TARGET/pymnn/src/private_define.h
|
||||
rm -f ../$TARGET/pymnn/src/renameForAliNNPython.h
|
||||
rm -f ../$TARGET/pymnn/MNNBridge.podspec
|
||||
rm -f ../$TARGET/source/backend/hiai/3rdParty
|
||||
popd > /dev/null
|
||||
|
||||
# reverting files
|
||||
pushd $TARGET > /dev/null
|
||||
git checkout -- benchmark/models/*.mnn
|
||||
git checkout -- project/android/build.gradle
|
||||
popd > /dev/null
|
||||
|
||||
# try re-build
|
||||
pushd $TARGET > /dev/null
|
||||
|
||||
# MNN
|
||||
rm -rf build
|
||||
rm -rf schema/private
|
||||
rm -rf schema/current
|
||||
|
||||
./schema/generate.sh
|
||||
mkdir build && cd build
|
||||
cmake .. -DMNN_BUILD_TEST=true -DMNN_BUILD_CONVERTER=true -DMNN_BUILD_QUANTOOLS=true
|
||||
make -j4
|
||||
./run_test.out
|
||||
|
||||
popd > /dev/null
|
|
@ -45,6 +45,9 @@ struct TensorDescribeT;
|
|||
struct SubGraphProto;
|
||||
struct SubGraphProtoT;
|
||||
|
||||
struct TensorQuantInfo;
|
||||
struct TensorQuantInfoT;
|
||||
|
||||
struct Net;
|
||||
struct NetT;
|
||||
|
||||
|
@ -68,6 +71,8 @@ inline const flatbuffers::TypeTable *TensorDescribeTypeTable();
|
|||
|
||||
inline const flatbuffers::TypeTable *SubGraphProtoTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *TensorQuantInfoTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *NetTypeTable();
|
||||
|
||||
enum OpType {
|
||||
|
@ -207,6 +212,7 @@ enum OpType {
|
|||
OpType_TensorArraySplit = 139,
|
||||
OpType_TensorArrayConcat = 140,
|
||||
OpType_LSTMBlockCell = 141,
|
||||
OpType_Reverse = 142,
|
||||
OpType_Plugin = 256,
|
||||
OpType_Select = 257,
|
||||
OpType_ZerosLike = 258,
|
||||
|
@ -230,11 +236,12 @@ enum OpType {
|
|||
OpType_While = 600,
|
||||
OpType_If = 601,
|
||||
OpType_LayerNorm = 603,
|
||||
OpType_GridSample = 604,
|
||||
OpType_MIN = OpType_AbsVal,
|
||||
OpType_MAX = OpType_LayerNorm
|
||||
OpType_MAX = OpType_GridSample
|
||||
};
|
||||
|
||||
inline const OpType (&EnumValuesOpType())[159] {
|
||||
inline const OpType (&EnumValuesOpType())[161] {
|
||||
static const OpType values[] = {
|
||||
OpType_AbsVal,
|
||||
OpType_QuantizedAdd,
|
||||
|
@ -372,6 +379,7 @@ inline const OpType (&EnumValuesOpType())[159] {
|
|||
OpType_TensorArraySplit,
|
||||
OpType_TensorArrayConcat,
|
||||
OpType_LSTMBlockCell,
|
||||
OpType_Reverse,
|
||||
OpType_Plugin,
|
||||
OpType_Select,
|
||||
OpType_ZerosLike,
|
||||
|
@ -394,7 +402,8 @@ inline const OpType (&EnumValuesOpType())[159] {
|
|||
OpType_EltwiseInt8,
|
||||
OpType_While,
|
||||
OpType_If,
|
||||
OpType_LayerNorm
|
||||
OpType_LayerNorm,
|
||||
OpType_GridSample
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
@ -543,7 +552,7 @@ inline const char * const *EnumNamesOpType() {
|
|||
"TensorArraySplit",
|
||||
"TensorArrayConcat",
|
||||
"LSTMBlockCell",
|
||||
"",
|
||||
"Reverse",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
|
@ -1005,13 +1014,14 @@ inline const char * const *EnumNamesOpType() {
|
|||
"If",
|
||||
"",
|
||||
"LayerNorm",
|
||||
"GridSample",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameOpType(OpType e) {
|
||||
if (e < OpType_AbsVal || e > OpType_LayerNorm) return "";
|
||||
if (e < OpType_AbsVal || e > OpType_GridSample) return "";
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesOpType()[index];
|
||||
}
|
||||
|
@ -1108,11 +1118,12 @@ enum OpParameter {
|
|||
OpParameter_LayerNorm = 88,
|
||||
OpParameter_TensorArray = 89,
|
||||
OpParameter_LSTMBlockCell = 90,
|
||||
OpParameter_GridSample = 91,
|
||||
OpParameter_MIN = OpParameter_NONE,
|
||||
OpParameter_MAX = OpParameter_LSTMBlockCell
|
||||
OpParameter_MAX = OpParameter_GridSample
|
||||
};
|
||||
|
||||
inline const OpParameter (&EnumValuesOpParameter())[91] {
|
||||
inline const OpParameter (&EnumValuesOpParameter())[92] {
|
||||
static const OpParameter values[] = {
|
||||
OpParameter_NONE,
|
||||
OpParameter_QuantizedAdd,
|
||||
|
@ -1204,7 +1215,8 @@ inline const OpParameter (&EnumValuesOpParameter())[91] {
|
|||
OpParameter_RandomUniform,
|
||||
OpParameter_LayerNorm,
|
||||
OpParameter_TensorArray,
|
||||
OpParameter_LSTMBlockCell
|
||||
OpParameter_LSTMBlockCell,
|
||||
OpParameter_GridSample
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
@ -1302,13 +1314,14 @@ inline const char * const *EnumNamesOpParameter() {
|
|||
"LayerNorm",
|
||||
"TensorArray",
|
||||
"LSTMBlockCell",
|
||||
"GridSample",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameOpParameter(OpParameter e) {
|
||||
if (e < OpParameter_NONE || e > OpParameter_LSTMBlockCell) return "";
|
||||
if (e < OpParameter_NONE || e > OpParameter_GridSample) return "";
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesOpParameter()[index];
|
||||
}
|
||||
|
@ -1677,6 +1690,10 @@ template<> struct OpParameterTraits<LSTMBlockCell> {
|
|||
static const OpParameter enum_value = OpParameter_LSTMBlockCell;
|
||||
};
|
||||
|
||||
template<> struct OpParameterTraits<GridSample> {
|
||||
static const OpParameter enum_value = OpParameter_GridSample;
|
||||
};
|
||||
|
||||
struct OpParameterUnion {
|
||||
OpParameter type;
|
||||
void *value;
|
||||
|
@ -2428,6 +2445,14 @@ struct OpParameterUnion {
|
|||
return type == OpParameter_LSTMBlockCell ?
|
||||
reinterpret_cast<const LSTMBlockCellT *>(value) : nullptr;
|
||||
}
|
||||
GridSampleT *AsGridSample() {
|
||||
return type == OpParameter_GridSample ?
|
||||
reinterpret_cast<GridSampleT *>(value) : nullptr;
|
||||
}
|
||||
const GridSampleT *AsGridSample() const {
|
||||
return type == OpParameter_GridSample ?
|
||||
reinterpret_cast<const GridSampleT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
|
||||
|
@ -3316,6 +3341,9 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const LSTMBlockCell *main_as_LSTMBlockCell() const {
|
||||
return main_type() == OpParameter_LSTMBlockCell ? static_cast<const LSTMBlockCell *>(main()) : nullptr;
|
||||
}
|
||||
const GridSample *main_as_GridSample() const {
|
||||
return main_type() == OpParameter_GridSample ? static_cast<const GridSample *>(main()) : nullptr;
|
||||
}
|
||||
const flatbuffers::String *name() const {
|
||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||
}
|
||||
|
@ -3708,6 +3736,10 @@ template<> inline const LSTMBlockCell *Op::main_as<LSTMBlockCell>() const {
|
|||
return main_as_LSTMBlockCell();
|
||||
}
|
||||
|
||||
template<> inline const GridSample *Op::main_as<GridSample>() const {
|
||||
return main_as_GridSample();
|
||||
}
|
||||
|
||||
struct OpBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
|
@ -3983,6 +4015,7 @@ struct TensorDescribeT : public flatbuffers::NativeTable {
|
|||
int32_t index;
|
||||
std::string name;
|
||||
std::vector<std::unique_ptr<RegionT>> regions;
|
||||
std::unique_ptr<TensorQuantInfoT> quantInfo;
|
||||
TensorDescribeT()
|
||||
: index(0) {
|
||||
}
|
||||
|
@ -3997,7 +4030,8 @@ struct TensorDescribe FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VT_BLOB = 4,
|
||||
VT_INDEX = 6,
|
||||
VT_NAME = 8,
|
||||
VT_REGIONS = 10
|
||||
VT_REGIONS = 10,
|
||||
VT_QUANTINFO = 12
|
||||
};
|
||||
const Blob *blob() const {
|
||||
return GetPointer<const Blob *>(VT_BLOB);
|
||||
|
@ -4011,6 +4045,9 @@ struct TensorDescribe FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::Vector<flatbuffers::Offset<Region>> *regions() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<Region>> *>(VT_REGIONS);
|
||||
}
|
||||
const TensorQuantInfo *quantInfo() const {
|
||||
return GetPointer<const TensorQuantInfo *>(VT_QUANTINFO);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyOffset(verifier, VT_BLOB) &&
|
||||
|
@ -4021,6 +4058,8 @@ struct TensorDescribe FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VerifyOffset(verifier, VT_REGIONS) &&
|
||||
verifier.VerifyVector(regions()) &&
|
||||
verifier.VerifyVectorOfTables(regions()) &&
|
||||
VerifyOffset(verifier, VT_QUANTINFO) &&
|
||||
verifier.VerifyTable(quantInfo()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
TensorDescribeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
|
@ -4043,6 +4082,9 @@ struct TensorDescribeBuilder {
|
|||
void add_regions(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Region>>> regions) {
|
||||
fbb_.AddOffset(TensorDescribe::VT_REGIONS, regions);
|
||||
}
|
||||
void add_quantInfo(flatbuffers::Offset<TensorQuantInfo> quantInfo) {
|
||||
fbb_.AddOffset(TensorDescribe::VT_QUANTINFO, quantInfo);
|
||||
}
|
||||
explicit TensorDescribeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
|
@ -4060,8 +4102,10 @@ inline flatbuffers::Offset<TensorDescribe> CreateTensorDescribe(
|
|||
flatbuffers::Offset<Blob> blob = 0,
|
||||
int32_t index = 0,
|
||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Region>>> regions = 0) {
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Region>>> regions = 0,
|
||||
flatbuffers::Offset<TensorQuantInfo> quantInfo = 0) {
|
||||
TensorDescribeBuilder builder_(_fbb);
|
||||
builder_.add_quantInfo(quantInfo);
|
||||
builder_.add_regions(regions);
|
||||
builder_.add_name(name);
|
||||
builder_.add_index(index);
|
||||
|
@ -4074,7 +4118,8 @@ inline flatbuffers::Offset<TensorDescribe> CreateTensorDescribeDirect(
|
|||
flatbuffers::Offset<Blob> blob = 0,
|
||||
int32_t index = 0,
|
||||
const char *name = nullptr,
|
||||
const std::vector<flatbuffers::Offset<Region>> *regions = nullptr) {
|
||||
const std::vector<flatbuffers::Offset<Region>> *regions = nullptr,
|
||||
flatbuffers::Offset<TensorQuantInfo> quantInfo = 0) {
|
||||
auto name__ = name ? _fbb.CreateString(name) : 0;
|
||||
auto regions__ = regions ? _fbb.CreateVector<flatbuffers::Offset<Region>>(*regions) : 0;
|
||||
return MNN::CreateTensorDescribe(
|
||||
|
@ -4082,7 +4127,8 @@ inline flatbuffers::Offset<TensorDescribe> CreateTensorDescribeDirect(
|
|||
blob,
|
||||
index,
|
||||
name__,
|
||||
regions__);
|
||||
regions__,
|
||||
quantInfo);
|
||||
}
|
||||
|
||||
flatbuffers::Offset<TensorDescribe> CreateTensorDescribe(flatbuffers::FlatBufferBuilder &_fbb, const TensorDescribeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
@ -4094,6 +4140,7 @@ struct SubGraphProtoT : public flatbuffers::NativeTable {
|
|||
std::vector<int32_t> outputs;
|
||||
std::vector<std::string> tensors;
|
||||
std::vector<std::unique_ptr<OpT>> nodes;
|
||||
std::vector<std::unique_ptr<TensorDescribeT>> extraTensorDescribe;
|
||||
SubGraphProtoT() {
|
||||
}
|
||||
};
|
||||
|
@ -4108,7 +4155,8 @@ struct SubGraphProto FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VT_INPUTS = 6,
|
||||
VT_OUTPUTS = 8,
|
||||
VT_TENSORS = 10,
|
||||
VT_NODES = 12
|
||||
VT_NODES = 12,
|
||||
VT_EXTRATENSORDESCRIBE = 14
|
||||
};
|
||||
const flatbuffers::String *name() const {
|
||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||
|
@ -4125,6 +4173,9 @@ struct SubGraphProto FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::Vector<flatbuffers::Offset<Op>> *nodes() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<Op>> *>(VT_NODES);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *extraTensorDescribe() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *>(VT_EXTRATENSORDESCRIBE);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyOffset(verifier, VT_NAME) &&
|
||||
|
@ -4139,6 +4190,9 @@ struct SubGraphProto FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VerifyOffset(verifier, VT_NODES) &&
|
||||
verifier.VerifyVector(nodes()) &&
|
||||
verifier.VerifyVectorOfTables(nodes()) &&
|
||||
VerifyOffset(verifier, VT_EXTRATENSORDESCRIBE) &&
|
||||
verifier.VerifyVector(extraTensorDescribe()) &&
|
||||
verifier.VerifyVectorOfTables(extraTensorDescribe()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
SubGraphProtoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
|
@ -4164,6 +4218,9 @@ struct SubGraphProtoBuilder {
|
|||
void add_nodes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Op>>> nodes) {
|
||||
fbb_.AddOffset(SubGraphProto::VT_NODES, nodes);
|
||||
}
|
||||
void add_extraTensorDescribe(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>>> extraTensorDescribe) {
|
||||
fbb_.AddOffset(SubGraphProto::VT_EXTRATENSORDESCRIBE, extraTensorDescribe);
|
||||
}
|
||||
explicit SubGraphProtoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
|
@ -4182,8 +4239,10 @@ inline flatbuffers::Offset<SubGraphProto> CreateSubGraphProto(
|
|||
flatbuffers::Offset<flatbuffers::Vector<int32_t>> inputs = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int32_t>> outputs = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> tensors = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Op>>> nodes = 0) {
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Op>>> nodes = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>>> extraTensorDescribe = 0) {
|
||||
SubGraphProtoBuilder builder_(_fbb);
|
||||
builder_.add_extraTensorDescribe(extraTensorDescribe);
|
||||
builder_.add_nodes(nodes);
|
||||
builder_.add_tensors(tensors);
|
||||
builder_.add_outputs(outputs);
|
||||
|
@ -4198,23 +4257,131 @@ inline flatbuffers::Offset<SubGraphProto> CreateSubGraphProtoDirect(
|
|||
const std::vector<int32_t> *inputs = nullptr,
|
||||
const std::vector<int32_t> *outputs = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *tensors = nullptr,
|
||||
const std::vector<flatbuffers::Offset<Op>> *nodes = nullptr) {
|
||||
const std::vector<flatbuffers::Offset<Op>> *nodes = nullptr,
|
||||
const std::vector<flatbuffers::Offset<TensorDescribe>> *extraTensorDescribe = nullptr) {
|
||||
auto name__ = name ? _fbb.CreateString(name) : 0;
|
||||
auto inputs__ = inputs ? _fbb.CreateVector<int32_t>(*inputs) : 0;
|
||||
auto outputs__ = outputs ? _fbb.CreateVector<int32_t>(*outputs) : 0;
|
||||
auto tensors__ = tensors ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*tensors) : 0;
|
||||
auto nodes__ = nodes ? _fbb.CreateVector<flatbuffers::Offset<Op>>(*nodes) : 0;
|
||||
auto extraTensorDescribe__ = extraTensorDescribe ? _fbb.CreateVector<flatbuffers::Offset<TensorDescribe>>(*extraTensorDescribe) : 0;
|
||||
return MNN::CreateSubGraphProto(
|
||||
_fbb,
|
||||
name__,
|
||||
inputs__,
|
||||
outputs__,
|
||||
tensors__,
|
||||
nodes__);
|
||||
nodes__,
|
||||
extraTensorDescribe__);
|
||||
}
|
||||
|
||||
flatbuffers::Offset<SubGraphProto> CreateSubGraphProto(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphProtoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct TensorQuantInfoT : public flatbuffers::NativeTable {
|
||||
typedef TensorQuantInfo TableType;
|
||||
float scale;
|
||||
float zero;
|
||||
float min;
|
||||
float max;
|
||||
DataType type;
|
||||
TensorQuantInfoT()
|
||||
: scale(0.0f),
|
||||
zero(0.0f),
|
||||
min(-128.0f),
|
||||
max(127.0f),
|
||||
type(DataType_DT_INVALID) {
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef TensorQuantInfoT NativeTableType;
|
||||
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
|
||||
return TensorQuantInfoTypeTable();
|
||||
}
|
||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||
VT_SCALE = 4,
|
||||
VT_ZERO = 6,
|
||||
VT_MIN = 8,
|
||||
VT_MAX = 10,
|
||||
VT_TYPE = 12
|
||||
};
|
||||
float scale() const {
|
||||
return GetField<float>(VT_SCALE, 0.0f);
|
||||
}
|
||||
float zero() const {
|
||||
return GetField<float>(VT_ZERO, 0.0f);
|
||||
}
|
||||
float min() const {
|
||||
return GetField<float>(VT_MIN, -128.0f);
|
||||
}
|
||||
float max() const {
|
||||
return GetField<float>(VT_MAX, 127.0f);
|
||||
}
|
||||
DataType type() const {
|
||||
return static_cast<DataType>(GetField<int32_t>(VT_TYPE, 0));
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<float>(verifier, VT_SCALE) &&
|
||||
VerifyField<float>(verifier, VT_ZERO) &&
|
||||
VerifyField<float>(verifier, VT_MIN) &&
|
||||
VerifyField<float>(verifier, VT_MAX) &&
|
||||
VerifyField<int32_t>(verifier, VT_TYPE) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
TensorQuantInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(TensorQuantInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<TensorQuantInfo> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorQuantInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct TensorQuantInfoBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_scale(float scale) {
|
||||
fbb_.AddElement<float>(TensorQuantInfo::VT_SCALE, scale, 0.0f);
|
||||
}
|
||||
void add_zero(float zero) {
|
||||
fbb_.AddElement<float>(TensorQuantInfo::VT_ZERO, zero, 0.0f);
|
||||
}
|
||||
void add_min(float min) {
|
||||
fbb_.AddElement<float>(TensorQuantInfo::VT_MIN, min, -128.0f);
|
||||
}
|
||||
void add_max(float max) {
|
||||
fbb_.AddElement<float>(TensorQuantInfo::VT_MAX, max, 127.0f);
|
||||
}
|
||||
void add_type(DataType type) {
|
||||
fbb_.AddElement<int32_t>(TensorQuantInfo::VT_TYPE, static_cast<int32_t>(type), 0);
|
||||
}
|
||||
explicit TensorQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
TensorQuantInfoBuilder &operator=(const TensorQuantInfoBuilder &);
|
||||
flatbuffers::Offset<TensorQuantInfo> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<TensorQuantInfo>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<TensorQuantInfo> CreateTensorQuantInfo(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
float scale = 0.0f,
|
||||
float zero = 0.0f,
|
||||
float min = -128.0f,
|
||||
float max = 127.0f,
|
||||
DataType type = DataType_DT_INVALID) {
|
||||
TensorQuantInfoBuilder builder_(_fbb);
|
||||
builder_.add_type(type);
|
||||
builder_.add_max(max);
|
||||
builder_.add_min(min);
|
||||
builder_.add_zero(zero);
|
||||
builder_.add_scale(scale);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<TensorQuantInfo> CreateTensorQuantInfo(flatbuffers::FlatBufferBuilder &_fbb, const TensorQuantInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct NetT : public flatbuffers::NativeTable {
|
||||
typedef Net TableType;
|
||||
std::string bizCode;
|
||||
|
@ -4715,6 +4882,7 @@ inline void TensorDescribe::UnPackTo(TensorDescribeT *_o, const flatbuffers::res
|
|||
{ auto _e = index(); _o->index = _e; };
|
||||
{ auto _e = name(); if (_e) _o->name = _e->str(); };
|
||||
{ auto _e = regions(); if (_e) { _o->regions.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regions[_i] = std::unique_ptr<RegionT>(_e->Get(_i)->UnPack(_resolver)); } } };
|
||||
{ auto _e = quantInfo(); if (_e) _o->quantInfo = std::unique_ptr<TensorQuantInfoT>(_e->UnPack(_resolver)); };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TensorDescribe> TensorDescribe::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorDescribeT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
|
@ -4729,12 +4897,14 @@ inline flatbuffers::Offset<TensorDescribe> CreateTensorDescribe(flatbuffers::Fla
|
|||
auto _index = _o->index;
|
||||
auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name);
|
||||
auto _regions = _o->regions.size() ? _fbb.CreateVector<flatbuffers::Offset<Region>> (_o->regions.size(), [](size_t i, _VectorArgs *__va) { return CreateRegion(*__va->__fbb, __va->__o->regions[i].get(), __va->__rehasher); }, &_va ) : 0;
|
||||
auto _quantInfo = _o->quantInfo ? CreateTensorQuantInfo(_fbb, _o->quantInfo.get(), _rehasher) : 0;
|
||||
return MNN::CreateTensorDescribe(
|
||||
_fbb,
|
||||
_blob,
|
||||
_index,
|
||||
_name,
|
||||
_regions);
|
||||
_regions,
|
||||
_quantInfo);
|
||||
}
|
||||
|
||||
inline SubGraphProtoT *SubGraphProto::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
|
@ -4751,6 +4921,7 @@ inline void SubGraphProto::UnPackTo(SubGraphProtoT *_o, const flatbuffers::resol
|
|||
{ auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } };
|
||||
{ auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tensors[_i] = _e->Get(_i)->str(); } } };
|
||||
{ auto _e = nodes(); if (_e) { _o->nodes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->nodes[_i] = std::unique_ptr<OpT>(_e->Get(_i)->UnPack(_resolver)); } } };
|
||||
{ auto _e = extraTensorDescribe(); if (_e) { _o->extraTensorDescribe.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->extraTensorDescribe[_i] = std::unique_ptr<TensorDescribeT>(_e->Get(_i)->UnPack(_resolver)); } } };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<SubGraphProto> SubGraphProto::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphProtoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
|
@ -4766,13 +4937,53 @@ inline flatbuffers::Offset<SubGraphProto> CreateSubGraphProto(flatbuffers::FlatB
|
|||
auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0;
|
||||
auto _tensors = _o->tensors.size() ? _fbb.CreateVectorOfStrings(_o->tensors) : 0;
|
||||
auto _nodes = _o->nodes.size() ? _fbb.CreateVector<flatbuffers::Offset<Op>> (_o->nodes.size(), [](size_t i, _VectorArgs *__va) { return CreateOp(*__va->__fbb, __va->__o->nodes[i].get(), __va->__rehasher); }, &_va ) : 0;
|
||||
auto _extraTensorDescribe = _o->extraTensorDescribe.size() ? _fbb.CreateVector<flatbuffers::Offset<TensorDescribe>> (_o->extraTensorDescribe.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorDescribe(*__va->__fbb, __va->__o->extraTensorDescribe[i].get(), __va->__rehasher); }, &_va ) : 0;
|
||||
return MNN::CreateSubGraphProto(
|
||||
_fbb,
|
||||
_name,
|
||||
_inputs,
|
||||
_outputs,
|
||||
_tensors,
|
||||
_nodes);
|
||||
_nodes,
|
||||
_extraTensorDescribe);
|
||||
}
|
||||
|
||||
inline TensorQuantInfoT *TensorQuantInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new TensorQuantInfoT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void TensorQuantInfo::UnPackTo(TensorQuantInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
{ auto _e = scale(); _o->scale = _e; };
|
||||
{ auto _e = zero(); _o->zero = _e; };
|
||||
{ auto _e = min(); _o->min = _e; };
|
||||
{ auto _e = max(); _o->max = _e; };
|
||||
{ auto _e = type(); _o->type = _e; };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TensorQuantInfo> TensorQuantInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorQuantInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateTensorQuantInfo(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TensorQuantInfo> CreateTensorQuantInfo(flatbuffers::FlatBufferBuilder &_fbb, const TensorQuantInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TensorQuantInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
auto _scale = _o->scale;
|
||||
auto _zero = _o->zero;
|
||||
auto _min = _o->min;
|
||||
auto _max = _o->max;
|
||||
auto _type = _o->type;
|
||||
return MNN::CreateTensorQuantInfo(
|
||||
_fbb,
|
||||
_scale,
|
||||
_zero,
|
||||
_min,
|
||||
_max,
|
||||
_type);
|
||||
}
|
||||
|
||||
inline NetT *Net::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
|
@ -5196,6 +5407,10 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
|
|||
auto ptr = reinterpret_cast<const LSTMBlockCell *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case OpParameter_GridSample: {
|
||||
auto ptr = reinterpret_cast<const GridSample *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
@ -5574,6 +5789,10 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
|
|||
auto ptr = reinterpret_cast<const LSTMBlockCell *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case OpParameter_GridSample: {
|
||||
auto ptr = reinterpret_cast<const GridSample *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -5940,6 +6159,10 @@ inline flatbuffers::Offset<void> OpParameterUnion::Pack(flatbuffers::FlatBufferB
|
|||
auto ptr = reinterpret_cast<const LSTMBlockCellT *>(value);
|
||||
return CreateLSTMBlockCell(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case OpParameter_GridSample: {
|
||||
auto ptr = reinterpret_cast<const GridSampleT *>(value);
|
||||
return CreateGridSample(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
@ -6306,6 +6529,10 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
|
|||
value = new LSTMBlockCellT(*reinterpret_cast<LSTMBlockCellT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case OpParameter_GridSample: {
|
||||
value = new GridSampleT(*reinterpret_cast<GridSampleT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -6763,6 +6990,11 @@ inline void OpParameterUnion::Reset() {
|
|||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case OpParameter_GridSample: {
|
||||
auto ptr = reinterpret_cast<GridSampleT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -6929,12 +7161,14 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
OpTypeTypeTable
|
||||
};
|
||||
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603 };
|
||||
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
|
||||
static const char * const names[] = {
|
||||
"AbsVal",
|
||||
"QuantizedAdd",
|
||||
|
@ -7072,6 +7306,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
"TensorArraySplit",
|
||||
"TensorArrayConcat",
|
||||
"LSTMBlockCell",
|
||||
"Reverse",
|
||||
"Plugin",
|
||||
"Select",
|
||||
"ZerosLike",
|
||||
|
@ -7094,10 +7329,11 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
"EltwiseInt8",
|
||||
"While",
|
||||
"If",
|
||||
"LayerNorm"
|
||||
"LayerNorm",
|
||||
"GridSample"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_ENUM, 159, type_codes, type_refs, values, names
|
||||
flatbuffers::ST_ENUM, 161, type_codes, type_refs, values, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
@ -7194,7 +7430,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
{ flatbuffers::ET_SEQUENCE, 0, 86 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 87 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 88 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 89 }
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 89 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 90 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
QuantizedAddTypeTable,
|
||||
|
@ -7286,7 +7523,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
RandomUniformTypeTable,
|
||||
LayerNormTypeTable,
|
||||
TensorArrayTypeTable,
|
||||
LSTMBlockCellTypeTable
|
||||
LSTMBlockCellTypeTable,
|
||||
GridSampleTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"NONE",
|
||||
|
@ -7379,10 +7617,11 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
"RandomUniform",
|
||||
"LayerNorm",
|
||||
"TensorArray",
|
||||
"LSTMBlockCell"
|
||||
"LSTMBlockCell",
|
||||
"GridSample"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_UNION, 91, type_codes, type_refs, nullptr, names
|
||||
flatbuffers::ST_UNION, 92, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
@ -7602,20 +7841,23 @@ inline const flatbuffers::TypeTable *TensorDescribeTypeTable() {
|
|||
{ flatbuffers::ET_SEQUENCE, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, -1 },
|
||||
{ flatbuffers::ET_STRING, 0, -1 },
|
||||
{ flatbuffers::ET_SEQUENCE, 1, 1 }
|
||||
{ flatbuffers::ET_SEQUENCE, 1, 1 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 2 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
BlobTypeTable,
|
||||
RegionTypeTable
|
||||
RegionTypeTable,
|
||||
TensorQuantInfoTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"blob",
|
||||
"index",
|
||||
"name",
|
||||
"regions"
|
||||
"regions",
|
||||
"quantInfo"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_TABLE, 4, type_codes, type_refs, nullptr, names
|
||||
flatbuffers::ST_TABLE, 5, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
@ -7626,17 +7868,44 @@ inline const flatbuffers::TypeTable *SubGraphProtoTypeTable() {
{ flatbuffers::ET_INT, 1, -1 },
{ flatbuffers::ET_INT, 1, -1 },
{ flatbuffers::ET_STRING, 1, -1 },
{ flatbuffers::ET_SEQUENCE, 1, 0 }
{ flatbuffers::ET_SEQUENCE, 1, 0 },
{ flatbuffers::ET_SEQUENCE, 1, 1 }
};
static const flatbuffers::TypeFunction type_refs[] = {
OpTypeTable
OpTypeTable,
TensorDescribeTypeTable
};
static const char * const names[] = {
"name",
"inputs",
"outputs",
"tensors",
"nodes"
"nodes",
"extraTensorDescribe"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 6, type_codes, type_refs, nullptr, names
};
return &tt;
}

inline const flatbuffers::TypeTable *TensorQuantInfoTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_FLOAT, 0, -1 },
{ flatbuffers::ET_FLOAT, 0, -1 },
{ flatbuffers::ET_FLOAT, 0, -1 },
{ flatbuffers::ET_FLOAT, 0, -1 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
DataTypeTypeTable
};
static const char * const names[] = {
"scale",
"zero",
"min",
"max",
"type"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 5, type_codes, type_refs, nullptr, names
@ -374,11 +374,12 @@ enum UnaryOpOperation {
UnaryOpOperation_EXPM1 = 28,
UnaryOpOperation_SIGMOID = 29,
UnaryOpOperation_TANH = 30,
UnaryOpOperation_HARDSWISH = 31,
UnaryOpOperation_MIN = UnaryOpOperation_ABS,
UnaryOpOperation_MAX = UnaryOpOperation_TANH
UnaryOpOperation_MAX = UnaryOpOperation_HARDSWISH
};

inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[31] {
inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[32] {
static const UnaryOpOperation values[] = {
UnaryOpOperation_ABS,
UnaryOpOperation_NEG,
@ -410,7 +411,8 @@ inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[31] {
UnaryOpOperation_ERFINV,
UnaryOpOperation_EXPM1,
UnaryOpOperation_SIGMOID,
UnaryOpOperation_TANH
UnaryOpOperation_TANH,
UnaryOpOperation_HARDSWISH
};
return values;
}
@ -448,13 +450,14 @@ inline const char * const *EnumNamesUnaryOpOperation() {
"EXPM1",
"SIGMOID",
"TANH",
"HARDSWISH",
nullptr
};
return names;
}

inline const char *EnumNameUnaryOpOperation(UnaryOpOperation e) {
if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_TANH) return "";
if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_HARDSWISH) return "";
const size_t index = static_cast<int>(e);
return EnumNamesUnaryOpOperation()[index];
}
@ -4981,6 +4984,7 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
@ -5017,10 +5021,11 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
"ERFINV",
"EXPM1",
"SIGMOID",
"TANH"
"TANH",
"HARDSWISH"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_ENUM, 31, type_codes, type_refs, nullptr, names
flatbuffers::ST_ENUM, 32, type_codes, type_refs, nullptr, names
};
return &tt;
}
@ -13,8 +13,76 @@ namespace MNN {
|
|||
struct TensorConvertInfo;
|
||||
struct TensorConvertInfoT;
|
||||
|
||||
struct GridSample;
|
||||
struct GridSampleT;
|
||||
|
||||
inline const flatbuffers::TypeTable *TensorConvertInfoTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *GridSampleTypeTable();
|
||||
|
||||
enum SampleMode {
|
||||
SampleMode_BILINEAR = 0,
|
||||
SampleMode_NEAREST = 1,
|
||||
SampleMode_MIN = SampleMode_BILINEAR,
|
||||
SampleMode_MAX = SampleMode_NEAREST
|
||||
};
|
||||
|
||||
inline const SampleMode (&EnumValuesSampleMode())[2] {
|
||||
static const SampleMode values[] = {
|
||||
SampleMode_BILINEAR,
|
||||
SampleMode_NEAREST
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesSampleMode() {
|
||||
static const char * const names[] = {
|
||||
"BILINEAR",
|
||||
"NEAREST",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameSampleMode(SampleMode e) {
|
||||
if (e < SampleMode_BILINEAR || e > SampleMode_NEAREST) return "";
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesSampleMode()[index];
|
||||
}
|
||||
|
||||
enum BorderMode {
|
||||
BorderMode_ZEROS = 0,
|
||||
BorderMode_CLAMP = 1,
|
||||
BorderMode_REFLECTION = 2,
|
||||
BorderMode_MIN = BorderMode_ZEROS,
|
||||
BorderMode_MAX = BorderMode_REFLECTION
|
||||
};
|
||||
|
||||
inline const BorderMode (&EnumValuesBorderMode())[3] {
|
||||
static const BorderMode values[] = {
|
||||
BorderMode_ZEROS,
|
||||
BorderMode_CLAMP,
|
||||
BorderMode_REFLECTION
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBorderMode() {
|
||||
static const char * const names[] = {
|
||||
"ZEROS",
|
||||
"CLAMP",
|
||||
"REFLECTION",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBorderMode(BorderMode e) {
|
||||
if (e < BorderMode_ZEROS || e > BorderMode_REFLECTION) return "";
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesBorderMode()[index];
|
||||
}
|
||||
|
||||
struct TensorConvertInfoT : public flatbuffers::NativeTable {
|
||||
typedef TensorConvertInfo TableType;
|
||||
MNN_DATA_FORMAT source;
|
||||
|
@ -84,6 +152,87 @@ inline flatbuffers::Offset<TensorConvertInfo> CreateTensorConvertInfo(
|
|||
|
||||
flatbuffers::Offset<TensorConvertInfo> CreateTensorConvertInfo(flatbuffers::FlatBufferBuilder &_fbb, const TensorConvertInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct GridSampleT : public flatbuffers::NativeTable {
|
||||
typedef GridSample TableType;
|
||||
SampleMode mode;
|
||||
BorderMode paddingMode;
|
||||
bool alignCorners;
|
||||
GridSampleT()
|
||||
: mode(SampleMode_BILINEAR),
|
||||
paddingMode(BorderMode_ZEROS),
|
||||
alignCorners(false) {
|
||||
}
|
||||
};
|
||||
|
||||
struct GridSample FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef GridSampleT NativeTableType;
|
||||
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
|
||||
return GridSampleTypeTable();
|
||||
}
|
||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||
VT_MODE = 4,
|
||||
VT_PADDINGMODE = 6,
|
||||
VT_ALIGNCORNERS = 8
|
||||
};
|
||||
SampleMode mode() const {
|
||||
return static_cast<SampleMode>(GetField<int8_t>(VT_MODE, 0));
|
||||
}
|
||||
BorderMode paddingMode() const {
|
||||
return static_cast<BorderMode>(GetField<int8_t>(VT_PADDINGMODE, 0));
|
||||
}
|
||||
bool alignCorners() const {
|
||||
return GetField<uint8_t>(VT_ALIGNCORNERS, 0) != 0;
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<int8_t>(verifier, VT_MODE) &&
|
||||
VerifyField<int8_t>(verifier, VT_PADDINGMODE) &&
|
||||
VerifyField<uint8_t>(verifier, VT_ALIGNCORNERS) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
GridSampleT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(GridSampleT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<GridSample> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GridSampleT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct GridSampleBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_mode(SampleMode mode) {
|
||||
fbb_.AddElement<int8_t>(GridSample::VT_MODE, static_cast<int8_t>(mode), 0);
|
||||
}
|
||||
void add_paddingMode(BorderMode paddingMode) {
|
||||
fbb_.AddElement<int8_t>(GridSample::VT_PADDINGMODE, static_cast<int8_t>(paddingMode), 0);
|
||||
}
|
||||
void add_alignCorners(bool alignCorners) {
|
||||
fbb_.AddElement<uint8_t>(GridSample::VT_ALIGNCORNERS, static_cast<uint8_t>(alignCorners), 0);
|
||||
}
|
||||
explicit GridSampleBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
GridSampleBuilder &operator=(const GridSampleBuilder &);
|
||||
flatbuffers::Offset<GridSample> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<GridSample>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<GridSample> CreateGridSample(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
SampleMode mode = SampleMode_BILINEAR,
|
||||
BorderMode paddingMode = BorderMode_ZEROS,
|
||||
bool alignCorners = false) {
|
||||
GridSampleBuilder builder_(_fbb);
|
||||
builder_.add_alignCorners(alignCorners);
|
||||
builder_.add_paddingMode(paddingMode);
|
||||
builder_.add_mode(mode);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<GridSample> CreateGridSample(flatbuffers::FlatBufferBuilder &_fbb, const GridSampleT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
inline TensorConvertInfoT *TensorConvertInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new TensorConvertInfoT();
|
||||
UnPackTo(_o, _resolver);
|
||||
|
@ -113,6 +262,76 @@ inline flatbuffers::Offset<TensorConvertInfo> CreateTensorConvertInfo(flatbuffer
|
|||
_dest);
|
||||
}
|
||||
|
||||
inline GridSampleT *GridSample::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new GridSampleT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void GridSample::UnPackTo(GridSampleT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
{ auto _e = mode(); _o->mode = _e; };
|
||||
{ auto _e = paddingMode(); _o->paddingMode = _e; };
|
||||
{ auto _e = alignCorners(); _o->alignCorners = _e; };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<GridSample> GridSample::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GridSampleT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateGridSample(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<GridSample> CreateGridSample(flatbuffers::FlatBufferBuilder &_fbb, const GridSampleT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GridSampleT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
auto _mode = _o->mode;
|
||||
auto _paddingMode = _o->paddingMode;
|
||||
auto _alignCorners = _o->alignCorners;
|
||||
return MNN::CreateGridSample(
|
||||
_fbb,
|
||||
_mode,
|
||||
_paddingMode,
|
||||
_alignCorners);
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *SampleModeTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_CHAR, 0, 0 },
|
||||
{ flatbuffers::ET_CHAR, 0, 0 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
SampleModeTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"BILINEAR",
|
||||
"NEAREST"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_ENUM, 2, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *BorderModeTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_CHAR, 0, 0 },
|
||||
{ flatbuffers::ET_CHAR, 0, 0 },
|
||||
{ flatbuffers::ET_CHAR, 0, 0 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
BorderModeTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"ZEROS",
|
||||
"CLAMP",
|
||||
"REFLECTION"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_ENUM, 3, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *TensorConvertInfoTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_CHAR, 0, 0 },
|
||||
|
@ -131,6 +350,27 @@ inline const flatbuffers::TypeTable *TensorConvertInfoTypeTable() {
|
|||
return &tt;
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *GridSampleTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_CHAR, 0, 0 },
|
||||
{ flatbuffers::ET_CHAR, 0, 1 },
|
||||
{ flatbuffers::ET_BOOL, 0, -1 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
SampleModeTypeTable,
|
||||
BorderModeTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"mode",
|
||||
"paddingMode",
|
||||
"alignCorners"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_TABLE, 3, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif // FLATBUFFERS_GENERATED_USERDEFINE_MNN_H_
@ -153,6 +153,7 @@ enum OpType : int {
TensorArraySplit = 139,
TensorArrayConcat = 140,
LSTMBlockCell = 141,
Reverse = 142,

Plugin = 256, //The Type load from plugin
//Training Op Start from 257
@ -183,6 +184,7 @@ enum OpType : int {
While = 600,
If = 601,
LayerNorm = 603,
GridSample = 604,
}

table Plugin {
@ -328,6 +330,7 @@ union OpParameter {
LayerNorm,
TensorArray,
LSTMBlockCell,
GridSample,
}

table Op {
@ -356,6 +359,7 @@ table TensorDescribe {
index: int;
name: string;
regions:[Region];
quantInfo:TensorQuantInfo;
}

enum ForwardType : byte {
@ -387,6 +391,17 @@ table SubGraphProto {

// Nodes of the subgraph.
nodes: [Op];

// Tensor describe info
extraTensorDescribe:[TensorDescribe];
}

table TensorQuantInfo {
scale:float;
zero:float = 0;
min:float = -128;
max:float = 127;
type:DataType;
}
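The TensorQuantInfo table above records an affine quantization for a tensor. As a rough illustration of how such a record is typically read (an assumption about the field semantics, not MNN's runtime code), a quantized value q within [min, max] maps back to float as scale * (q - zero):

#include <cstdint>

// Illustrative affine dequantization based on the TensorQuantInfo fields.
inline float dequantizeValue(int32_t q, float scale, float zero) {
    return scale * (static_cast<float>(q) - zero);
}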

table Net {
@ -139,6 +139,7 @@ enum UnaryOpOperation : int {
EXPM1 = 28,
SIGMOID = 29,
TANH = 30,
HARDSWISH = 31,
}

table UnaryOp {
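For reference, the newly added HARDSWISH operation computes x * relu6(x + 3) / 6; a minimal scalar sketch of that definition (not the MNN kernel itself):

#include <algorithm>

// hardswish(x) = x * relu6(x + 3) / 6
inline float hardswish(float x) {
    float t = std::min(std::max(x + 3.0f, 0.0f), 6.0f); // relu6(x + 3)
    return x * t / 6.0f;
}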
@ -4,3 +4,19 @@ table TensorConvertInfo {
source:MNN_DATA_FORMAT;
dest:MNN_DATA_FORMAT;
}

enum SampleMode : byte {
BILINEAR=0,
NEAREST
}
enum BorderMode : byte {
ZEROS=0,
CLAMP,
REFLECTION
}

table GridSample {
mode:SampleMode;
paddingMode:BorderMode;
alignCorners:bool=false;
}
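As a rough usage sketch of the generated builder shown earlier in this commit (the include name and surrounding setup are assumptions, not part of the change):

#include "flatbuffers/flatbuffers.h"
#include "UserDefine_generated.h"  // assumed name of the header generated from this schema

// Serialize a GridSample parameter with bilinear sampling and zero padding,
// using the CreateGridSample helper added by this commit.
flatbuffers::Offset<MNN::GridSample> buildGridSample(flatbuffers::FlatBufferBuilder& fbb) {
    return MNN::CreateGridSample(fbb, MNN::SampleMode_BILINEAR, MNN::BorderMode_ZEROS,
                                 /*alignCorners=*/false);
}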
@ -5,17 +5,18 @@
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#include <algorithm>
#include <mutex>

#include "backend/arm82/Arm82Backend.hpp"
#include "backend/arm82/Arm82OptFunc.hpp"
#include "Arm82Backend.hpp"
#include "Arm82OptFunc.hpp"
#include "Arm82Functions.hpp"
#include "core/BufferAllocator.hpp"
#include "core/TensorUtils.hpp"

#include "core/OpCommonUtils.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "half.hpp"

namespace MNN {
@ -37,8 +38,8 @@ bool Arm82Backend::addArm82Creator(OpType t, Arm82Creator* ct) {
return true;
}

Arm82Backend::Arm82Backend(const CPURuntime* runtime) : CPUBackend(runtime, MNN_FORWARD_CPU_EXTENSION) {
// nothing to do
Arm82Backend::Arm82Backend(const CPURuntime* runtime) : CPUBackend(runtime, BackendConfig::Precision_Low, MNN_FORWARD_CPU_EXTENSION) {
mCoreFunctions = Arm82Functions::get();
}

Arm82Backend::~Arm82Backend() {
@ -52,6 +53,14 @@ Execution* Arm82Backend::onCreate(const std::vector<Tensor*>& inputs, const std:
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto quantInfo = OpCommonUtils::getQuantInfo(inputs);
|
||||
if (quantInfo.first) {
|
||||
return nullptr;
|
||||
}
|
||||
bool originCreate = OpCommonUtils::opCompabilityForLowp(op);
|
||||
if (originCreate) {
|
||||
return CPUBackend::onCreate(inputs, outputs, op);
|
||||
}
|
||||
auto creatorContainer = getArm82CreatorContainer();
|
||||
// MNN_PRINT("====> create Execution for type: %s\n", MNN::EnumNameOpType(op->type()));
|
||||
auto iter = creatorContainer->find(op->type());
|
||||
|
@ -88,7 +97,7 @@ bool Arm82Backend::onAcquireBuffer(const Tensor* nativeTensor, StorageType stora
|
|||
// arm82 backend tensor data type is fp16 default
|
||||
auto tensor = const_cast<Tensor*>(nativeTensor);
|
||||
auto& buffer = tensor->buffer();
|
||||
if (buffer.type != halide_type_of<float>()) {
|
||||
if (buffer.type != halide_type_of<float>() && buffer.type != halide_type_of<FLOAT16>()) {
|
||||
return CPUBackend::onAcquireBuffer(nativeTensor, storageType);
|
||||
}
|
||||
auto res = allocBuffer(_getAliginSize(buffer, TensorUtils::getDescribe(nativeTensor)->dimensionFormat), (Tensor*)nativeTensor, storageType);
|
||||
|
@ -128,7 +137,7 @@ static void _convertFp16Inside(const halide_buffer_t& ib, const halide_buffer_t&
|
|||
const int outBatchStide = channel * area;
|
||||
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
MNNNC8HW8TONCHW_NO_TYPE((uint16_t*)ob.host + outBatchStide * i, (const uint16_t*)ib.host + inbatchStride * i, area,
|
||||
MNNUnPackC8FP16((FLOAT16*)ob.host + outBatchStide * i, (const FLOAT16*)ib.host + inbatchStride * i, area,
|
||||
channel);
|
||||
}
|
||||
return;
|
||||
|
@ -138,7 +147,7 @@ static void _convertFp16Inside(const halide_buffer_t& ib, const halide_buffer_t&
|
|||
const int inbatchStride = channel * area;
|
||||
const int outBatchStide = UP_DIV(channel, ARMV82_CHANNEL_UNIT) * area * ARMV82_CHANNEL_UNIT;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
MNNNCHWTONC8HW8_NO_TYPE((uint16_t*)ob.host + outBatchStide * i, (const uint16_t*)ib.host + inbatchStride * i, area,
|
||||
MNNPackC8FP16((FLOAT16*)ob.host + outBatchStide * i, (const FLOAT16*)ib.host + inbatchStride * i, area,
|
||||
channel);
|
||||
}
|
||||
return;
|
||||
|
@ -200,14 +209,14 @@ void Arm82Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor
|
|||
const int outBatchStride = UP_DIV(channel, ARMV82_CHANNEL_UNIT) * area * ARMV82_CHANNEL_UNIT;
|
||||
const int inbatchStride = UP_DIV(channel, 4) * area * 4;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
MNNNC4HW4TONC8HW8(dstTensor->host<uint16_t>() + outBatchStride * i, srcTensor->host<float>() + inbatchStride * i, area,
|
||||
MNNNC4HW4TONC8HW8(dstTensor->host<FLOAT16>() + outBatchStride * i, srcTensor->host<float>() + inbatchStride * i, area,
|
||||
channel);
|
||||
}
|
||||
} else {
|
||||
const int inbatchStride = UP_DIV(channel, ARMV82_CHANNEL_UNIT) * area * ARMV82_CHANNEL_UNIT;
|
||||
const int outBatchStide = UP_DIV(channel, 4) * area * 4;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
MNNNC8HW8TONC4HW4(dstTensor->host<float>() + outBatchStide * i, srcTensor->host<uint16_t>() + inbatchStride * i, area,
|
||||
MNNNC8HW8TONC4HW4(dstTensor->host<float>() + outBatchStide * i, srcTensor->host<FLOAT16>() + inbatchStride * i, area,
|
||||
channel);
|
||||
}
|
||||
}
|
||||
|
@ -220,15 +229,15 @@ void Arm82Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor
|
|||
// cpu -> arm82 copy
|
||||
if (srcType == MNN_FORWARD_CPU) {
|
||||
const auto src = srcTensor->host<float>();
|
||||
auto dst = dstTensor->host<FLOAT16>();
|
||||
MNNQuantizeFP16(dst, src, elemenSize);
|
||||
auto dst = dstTensor->host<int16_t>();
|
||||
MNNQuantizeFP16(src, dst, elemenSize);
|
||||
return;
|
||||
}
|
||||
// arm82 -> cpu copy
|
||||
if (srcType == MNN_FORWARD_CPU_EXTENSION) {
|
||||
const auto src = srcTensor->host<int16_t>();
|
||||
auto dst = dstTensor->host<float>();
|
||||
MNNDequantizeFP16(dst, src, elemenSize);
|
||||
MNNDequantizeFP16(src, dst, elemenSize);
|
||||
return;
|
||||
}
|
||||
MNN_ERROR("Invalid copy for internal Arm82 Backend\n");
|
||||
|
@ -236,6 +245,7 @@ void Arm82Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor
|
|||
}
|
||||
|
||||
void registerArm82RuntimeCreator() {
|
||||
Arm82Functions::init();
|
||||
registerArm82Ops();
|
||||
};
|
||||
#ifndef MNN_CODEGEN_REGISTER
|
||||
|
@ -246,5 +256,4 @@ static const auto __arm82_global_initializer = []() {
|
|||
#endif
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif
@ -5,19 +5,25 @@
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Backend_hpp
#define Arm82Backend_hpp

#include "backend/cpu/CPUBackend.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include <MNN/HalideRuntime.h>

// armv82's data type default is fp16, so set
// armv82's dataformat: NC8HW8
#define ARMV82_CHANNEL_UNIT 8

typedef __fp16 FLOAT16;
template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<FLOAT16>() {
return halide_type_t(halide_type_float, 16);
}

namespace MNN {
class Arm82Backend : public CPUBackend {
@ -60,8 +66,19 @@ inline int ARM82TensorElementSizeHelper(const Tensor* t) {
return size;
}

inline int ARM82TensorStrideHelper(const Tensor* t, int dim) {
int size = 1;
for (int i = t->dimensions() - 1; i > dim; i--) {
int currentDimSize = t->length(i);
if (TensorUtils::getDescribe(t)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4 && 1 == i) {
currentDimSize = UP_DIV(currentDimSize, 8) * 8;
}
size *= currentDimSize;
}
return size;
}
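As a quick worked check of the stride helper above (illustrative numbers, not taken from the commit): for a tensor described as NC4HW4 with dimensions [1, 20, 32, 32] (batch, channel, height, width), the channel length is padded to UP_DIV(20, 8) * 8 = 24, so ARM82TensorStrideHelper(t, 0) returns 24 * 32 * 32 = 24576 elements, the per-batch stride in the backend's NC8HW8 layout.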

} // namespace MNN

#endif /* Arm82Backend_hpp */

#endif
@ -6,7 +6,7 @@
// Copyright © 2021, Alibaba Group Holding Limited
//

#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)
#include <algorithm>
#include "backend/arm82/Arm82Binary.hpp"
#include "backend/arm82/Arm82Backend.hpp"

@ -5,7 +5,8 @@
// Created by MNN on 2021/01/05.
// Copyright © 2021, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Binary_hpp
#define Arm82Binary_hpp

@ -1,471 +0,0 @@
|
|||
//
|
||||
// Arm82Convolution.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/01/07.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
#include "backend/arm82/Arm82Convolution.hpp"
|
||||
#include "backend/arm82/Arm82Backend.hpp"
|
||||
#include "backend/arm82/Arm82Convolution3x3.hpp"
|
||||
#include "backend/arm82/Arm82OptFunc.hpp"
|
||||
#include "core/Concurrency.h"
|
||||
#include "core/Macro.h"
|
||||
#include "core/TensorUtils.hpp"
|
||||
#include "core/ConvolutionCommon.hpp"
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
namespace MNN {
|
||||
|
||||
#ifndef MNN_USE_NEON
|
||||
static void MNNGemmFP16C8_UNIT(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 *weight, const FLOAT16 *bias,
|
||||
size_t src_loop, size_t dst_step, size_t dst_loop, size_t relu, size_t relu6,
|
||||
size_t realDstCount) {
|
||||
const auto dst_step_tmp = dst_step / sizeof(FLOAT16);
|
||||
|
||||
for (int dz = 0; dz < dst_loop; ++dz) {
|
||||
const auto weight_dz = weight + dz * src_loop * (ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT);
|
||||
const auto bias_dz = bias + dz * ARMV82_CHANNEL_UNIT;
|
||||
auto dst_z = dst + dz * dst_step_tmp;
|
||||
for (int w = 0; w < DST_XUNIT; ++w) {
|
||||
const auto src_x = src + w * ARMV82_CHANNEL_UNIT;
|
||||
auto dst_x = dst_z + w * ARMV82_CHANNEL_UNIT;
|
||||
FLOAT16 dstTemp[ARMV82_CHANNEL_UNIT];
|
||||
|
||||
memcpy(dstTemp, bias_dz, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
|
||||
// MAC
|
||||
for (int sz = 0; sz < src_loop; ++sz) {
|
||||
const auto weight_sz = weight_dz + (ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT) * sz;
|
||||
const auto src_z = src_x + sz * DST_XUNIT * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
|
||||
dstTemp[j] += src_z[i] * weight_sz[i * ARMV82_CHANNEL_UNIT + j];
|
||||
}
|
||||
}
|
||||
} // end MAC
|
||||
|
||||
if (relu) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (dstTemp[j] < 0) {
|
||||
dstTemp[j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (relu6) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (dstTemp[j] < 0) {
|
||||
dstTemp[j] = 0;
|
||||
}
|
||||
if (dstTemp[j] > 6) {
|
||||
dstTemp[j] = 6.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(dst_x, dstTemp, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static void Im2ColTransformer(FLOAT16 *dst, const FLOAT16 *src, ConvolutionCommon::Im2ColParameter *im2colParam,
|
||||
size_t xIndexStart, size_t realDstCount) {
|
||||
{
|
||||
const int colBufferSize = im2colParam->kernelCountUnit * DST_XUNIT * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16);
|
||||
memset(dst, 0, colBufferSize);
|
||||
}
|
||||
// src data format is nc8hw8
|
||||
|
||||
const auto ih = im2colParam->ih;
|
||||
const auto iw = im2colParam->iw;
|
||||
// const auto oh = im2colParameter->oh;
|
||||
const auto ow = im2colParam->ow;
|
||||
const auto kh = im2colParam->kernelY;
|
||||
const auto kw = im2colParam->kernelX;
|
||||
const auto dilateX = im2colParam->dilateX;
|
||||
const auto dilateY = im2colParam->dilateY;
|
||||
const auto icDiv4 = im2colParam->icDiv4;
|
||||
const auto srcChannleStride = iw * ih * ARMV82_CHANNEL_UNIT;
|
||||
const auto stridex = im2colParam->strideX;
|
||||
const auto stridey = im2colParam->strideY;
|
||||
const auto padx = im2colParam->padX;
|
||||
const auto pady = im2colParam->padY;
|
||||
constexpr int dstXStep = ARMV82_CHANNEL_UNIT * DST_XUNIT;
|
||||
|
||||
for (int i = 0; i < realDstCount; ++i) {
|
||||
int xIndex = (int)xIndexStart + i;
|
||||
int ox = xIndex % ow;
|
||||
int oy = xIndex / ow;
|
||||
int sx = ox * stridex - padx;
|
||||
int sy = oy * stridey - pady;
|
||||
int sfy = ALIMAX(0, (UP_DIV(-sy, dilateY)));
|
||||
int efy = ALIMIN(kh, UP_DIV(ih - sy, dilateY));
|
||||
int sfx = ALIMAX(0, (UP_DIV(-sx, dilateX)));
|
||||
int efx = ALIMIN(kw, UP_DIV(iw - sx, dilateX));
|
||||
int fyC = efy - sfy;
|
||||
int fxC = efx - sfx;
|
||||
|
||||
auto colAddrI = dst + ARMV82_CHANNEL_UNIT * i;
|
||||
auto inputOffset = src + (sx + sfx * dilateX + (sy + sfy * dilateY) * iw) * ARMV82_CHANNEL_UNIT;
|
||||
auto indexOffset = (sfy * kw + sfx) * icDiv4;
|
||||
|
||||
for (int fy = 0; fy < fyC; ++fy) {
|
||||
for (int fx = 0; fx < fxC; ++fx) {
|
||||
auto inputUnit = inputOffset + (fx * dilateX + fy * dilateY * iw) * ARMV82_CHANNEL_UNIT;
|
||||
auto indexStart = (indexOffset + (fy * kw + fx) * icDiv4) * dstXStep;
|
||||
for (int sz = 0; sz < icDiv4; ++sz) {
|
||||
auto dstUnit = colAddrI + indexStart + sz * dstXStep;
|
||||
memcpy(dstUnit, inputUnit, ARMV82_CHANNEL_UNIT * sizeof(FLOAT16));
|
||||
inputUnit += srcChannleStride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shuffle channel
|
||||
#ifdef MNN_USE_NEON
|
||||
if (realDstCount > (DST_XUNIT / 2)) {
|
||||
MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, 0);
|
||||
} else {
|
||||
MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, 1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void Im2ColTransformer1x1(FLOAT16 *dst, const FLOAT16 *src, ConvolutionCommon::Im2ColParameter *im2colParam,
|
||||
size_t xIndexStart, size_t realDstCount) {
|
||||
{
|
||||
const int colBufferSize = im2colParam->kernelCountUnit * DST_XUNIT * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16);
|
||||
memset(dst, 0, colBufferSize);
|
||||
}
|
||||
// src data format is nc8hw8
|
||||
const auto ih = im2colParam->ih;
|
||||
const auto iw = im2colParam->iw;
|
||||
|
||||
const auto icDiv8 = im2colParam->icDiv4;
|
||||
const auto srcChannleStride = iw * ih * ARMV82_CHANNEL_UNIT;
|
||||
constexpr int dstXStep = ARMV82_CHANNEL_UNIT * DST_XUNIT;
|
||||
const auto srcStartPtr = src + xIndexStart * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
for (int c = 0; c < icDiv8; ++c) {
|
||||
memcpy(dst + c * dstXStep, srcStartPtr + c * srcChannleStride,
|
||||
sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT * realDstCount);
|
||||
}
|
||||
|
||||
// shuffle channel
|
||||
#ifdef MNN_USE_NEON
|
||||
if (realDstCount > (DST_XUNIT / 2)) {
|
||||
MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, 0);
|
||||
} else {
|
||||
MNNShuffleChannelC8(dst, dst, (size_t)im2colParam->kernelCountUnit, 1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Arm82Convolution::Arm82Convolution(const MNN::Convolution2D *convParam, Backend *bn) : Execution(bn) {
|
||||
const auto convCommon = convParam->common();
|
||||
mCommon = convCommon;
|
||||
const int kx = convCommon->kernelX();
|
||||
const int ky = convCommon->kernelY();
|
||||
const int kernelCount = kx * ky;
|
||||
int inputChannel = convCommon->inputCount();
|
||||
const int outputChannel = convCommon->outputCount();
|
||||
if (inputChannel == 0) {
|
||||
if (convParam->quanParameter()) {
|
||||
inputChannel = convParam->quanParameter()->buffer()->size() / (2 * kernelCount * outputChannel);
|
||||
} else {
|
||||
inputChannel = convParam->weight()->size() / (kernelCount * outputChannel);
|
||||
}
|
||||
}
|
||||
const int inputChannelUnit = UP_DIV(inputChannel, ARMV82_CHANNEL_UNIT);
|
||||
const int outputChannelUnit = UP_DIV(outputChannel, ARMV82_CHANNEL_UNIT);
|
||||
|
||||
const int totalKernelCountUnit = kernelCount * inputChannelUnit;
|
||||
mWeightFp16.reset(Tensor::createDevice<uint16_t>(
|
||||
{outputChannelUnit, totalKernelCountUnit, ARMV82_CHANNEL_UNIT, ARMV82_CHANNEL_UNIT}));
|
||||
auto allocRes = bn->onAcquireBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
if (!allocRes) {
|
||||
mValid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto weightFp16DstPtr = mWeightFp16->host<FLOAT16>();
|
||||
memset(weightFp16DstPtr, 0, mWeightFp16->size());
|
||||
|
||||
const FLOAT16 *fp16WeightPtr = nullptr;
|
||||
std::vector<FLOAT16> weightFp16;
|
||||
if (convParam->quanParameter()) {
|
||||
MNN_ASSERT((convParam->quanParameter()->type() == 3) || (convParam->quanParameter()->type() == 4));
|
||||
if (convParam->quanParameter()->type() == 3) {
|
||||
// the data type of weight is fp16
|
||||
fp16WeightPtr = reinterpret_cast<const FLOAT16 *>(convParam->quanParameter()->buffer()->data());
|
||||
}
|
||||
if (convParam->quanParameter()->type() == 4) {
|
||||
std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
|
||||
quanCommon = ConvolutionCommon::load(convParam->quanParameter(), true);
|
||||
int weightCount = convParam->quanParameter()->buffer()->size();
|
||||
weightFp16.resize(weightCount);
|
||||
MNNQuantizeFP16(weightFp16.data(), quanCommon->weightFloat.get(), weightCount);
|
||||
fp16WeightPtr = weightFp16.data();
|
||||
}
|
||||
} else {
|
||||
// the data type of weight is fp32, then quantize weight to be fp16 data type
|
||||
int size = convParam->weight()->size();
|
||||
weightFp16.resize(size);
|
||||
MNNQuantizeFP16(weightFp16.data(), convParam->weight()->data(), size);
|
||||
fp16WeightPtr = weightFp16.data();
|
||||
}
|
||||
|
||||
auto weightFp16SrcPtr = fp16WeightPtr;
|
||||
|
||||
const int oneChannleKernelSize = kernelCount * inputChannel;
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
int curOcChannel = 0;
|
||||
auto reorderWeight = [&](int ocUnit, int ocUnitNum, const FLOAT16 *weightSrc, FLOAT16 *weightDst) {
|
||||
for (int oc = 0; oc < ocUnitNum; ++oc) {
|
||||
auto weightDstOcUnit = weightDst + oc * kernelCount * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocUnit;
|
||||
const auto weightSrcOc = weightSrc + oc * ocUnit * oneChannleKernelSize;
|
||||
for (int k = 0; k < kernelCount; ++k) {
|
||||
auto weightDstK = weightDstOcUnit + k * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocUnit;
|
||||
const auto weightSrcK = weightSrcOc + k;
|
||||
for (int y = 0; y < inputChannel; ++y) {
|
||||
const int yOutSide = y / ARMV82_CHANNEL_UNIT;
|
||||
const int yInSide = y % ARMV82_CHANNEL_UNIT;
|
||||
auto weightDstIc = weightDstK + yOutSide * ARMV82_CHANNEL_UNIT * ocUnit + yInSide * ocUnit;
|
||||
const auto weigthSrcIc = weightSrcK + y * kernelCount;
|
||||
|
||||
for (int x = 0; x < ocUnit; ++x) {
|
||||
if (curOcChannel + x < outputChannel) {
|
||||
weightDstIc[x] = weigthSrcIc[x * oneChannleKernelSize];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
curOcChannel += ocUnit;
|
||||
}
|
||||
};
|
||||
const int ocDivDoubleUnit = outputChannelUnit / 2;
|
||||
// reorder weight in double ARMV82_CHANNEL_UNIT
|
||||
reorderWeight((ARMV82_CHANNEL_UNIT * 2), ocDivDoubleUnit, weightFp16SrcPtr, weightFp16DstPtr);
|
||||
auto weightRemainDst = weightFp16DstPtr + kernelCount * inputChannelUnit * ARMV82_CHANNEL_UNIT * ocDivDoubleUnit *
|
||||
(ARMV82_CHANNEL_UNIT * 2);
|
||||
auto weightRemainSrc = weightFp16SrcPtr + kernelCount * inputChannel * ocDivDoubleUnit * (ARMV82_CHANNEL_UNIT * 2);
|
||||
if (outputChannelUnit % 2 == 1) {
|
||||
// reorder weight in ARMV82_CHANNEL_UNIT
|
||||
reorderWeight(ARMV82_CHANNEL_UNIT, 1, weightRemainSrc, weightRemainDst);
|
||||
}
|
||||
#else
|
||||
// reorder weight
|
||||
const int ocUnitStride = inputChannelUnit * ARMV82_CHANNEL_UNIT * kernelCount * ARMV82_CHANNEL_UNIT;
|
||||
for (int k = 0; k < kernelCount; ++k) {
|
||||
const auto weightSrcK = weightFp16SrcPtr + k;
|
||||
auto weightDstK = weightFp16DstPtr + k * inputChannelUnit * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT;
|
||||
for (int y = 0; y < inputChannel; ++y) {
|
||||
const int yOutSide = y / ARMV82_CHANNEL_UNIT;
|
||||
const int yInSide = y % ARMV82_CHANNEL_UNIT;
|
||||
|
||||
auto dstY =
|
||||
weightDstK + yOutSide * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT + yInSide * ARMV82_CHANNEL_UNIT;
|
||||
const auto srcY = weightSrcK + y * kernelCount;
|
||||
for (int x = 0; x < outputChannel; ++x) {
|
||||
const int xOutSide = x / ARMV82_CHANNEL_UNIT;
|
||||
const int xInSide = x % ARMV82_CHANNEL_UNIT;
|
||||
const int dstIndex = xOutSide * ocUnitStride + xInSide;
|
||||
const int srcIndex = x * oneChannleKernelSize;
|
||||
dstY[dstIndex] = srcY[srcIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
mBiasFp16.reset(Tensor::createDevice<uint16_t>({outputChannelUnit * ARMV82_CHANNEL_UNIT}));
|
||||
allocRes = bn->onAcquireBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
if (!allocRes) {
|
||||
mValid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO, bias is fp32, save bias also in fp16?
|
||||
auto biasDstPtr = mBiasFp16->host<FLOAT16>();
|
||||
memset(biasDstPtr, 0, mBiasFp16->size());
|
||||
MNNQuantizeFP16(biasDstPtr, convParam->bias()->data(), outputChannel);
|
||||
|
||||
mIm2ColParamter.dilateX = convCommon->dilateX();
|
||||
mIm2ColParamter.dilateY = convCommon->dilateY();
|
||||
mIm2ColParamter.strideX = convCommon->strideX();
|
||||
mIm2ColParamter.strideY = convCommon->strideY();
|
||||
mIm2ColParamter.padX = convCommon->padX();
|
||||
mIm2ColParamter.padY = convCommon->padY();
|
||||
mIm2ColParamter.icDiv4 = inputChannelUnit;
|
||||
mIm2ColParamter.kernelX = convCommon->kernelX();
|
||||
mIm2ColParamter.kernelY = convCommon->kernelY();
|
||||
mIm2ColParamter.kernelCountUnit = totalKernelCountUnit;
|
||||
|
||||
mRelu6 = convCommon->relu6();
|
||||
mRelu = convCommon->relu();
|
||||
}
|
||||
|
||||
Arm82Convolution::~Arm82Convolution() {
|
||||
if (mWeightFp16 != nullptr) {
|
||||
backend()->onReleaseBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
}
|
||||
if (mBiasFp16 != nullptr) {
|
||||
backend()->onReleaseBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode Arm82Convolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
|
||||
mIm2ColParamter.padX = mCommon->padX();
|
||||
mIm2ColParamter.padY = mCommon->padY();
|
||||
if (mCommon->padMode() == PadMode_SAME) {
|
||||
int kernelWidthSize = (mCommon->kernelX() - 1) * mCommon->dilateX() + 1;
|
||||
int kernelHeightSize = (mCommon->kernelY() - 1) * mCommon->dilateY() + 1;
|
||||
|
||||
int padNeededWidth = (output->width() - 1) * mCommon->strideX() + kernelWidthSize - input->width();
|
||||
int padNeededHeight = (output->height() - 1) * mCommon->strideY() + kernelHeightSize - input->height();
|
||||
mIm2ColParamter.padX = padNeededWidth / 2;
|
||||
mIm2ColParamter.padY = padNeededHeight / 2;
|
||||
}
|
||||
|
||||
mIm2ColParamter.ih = input->height();
|
||||
mIm2ColParamter.iw = input->width();
|
||||
mIm2ColParamter.oh = output->height();
|
||||
mIm2ColParamter.ow = output->width();
|
||||
|
||||
mTileCount = UP_DIV(output->height() * output->width(), DST_XUNIT);
|
||||
const int threads = std::max(1, static_cast<Arm82Backend *>(backend())->numberThread());
|
||||
mThreadNums = std::min(threads, mTileCount);
|
||||
|
||||
mIm2ColBuffer.setType(DataType_DT_BFLOAT16);
|
||||
mIm2ColBuffer.buffer().dimensions = 3;
|
||||
mIm2ColBuffer.setLength(0, mThreadNums);
|
||||
mIm2ColBuffer.setLength(1, DST_XUNIT);
|
||||
mIm2ColBuffer.setLength(2, mWeightFp16->length(1) * ARMV82_CHANNEL_UNIT);
|
||||
TensorUtils::setLinearLayout(&mIm2ColBuffer);
|
||||
|
||||
mRemainBuffer.setType(DataType_DT_BFLOAT16);
|
||||
mRemainBuffer.buffer().dimensions = 3;
|
||||
mRemainBuffer.setLength(0, mThreadNums);
|
||||
mRemainBuffer.setLength(1, DST_XUNIT);
|
||||
mRemainBuffer.setLength(2, UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT) * ARMV82_CHANNEL_UNIT);
|
||||
TensorUtils::setLinearLayout(&mRemainBuffer);
|
||||
bool success = backend()->onAcquireBuffer(&mIm2ColBuffer, Backend::DYNAMIC);
|
||||
success = success && backend()->onAcquireBuffer(&mRemainBuffer, Backend::DYNAMIC);
|
||||
if (!success) {
|
||||
return OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
backend()->onReleaseBuffer(&mIm2ColBuffer, Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(&mRemainBuffer, Backend::DYNAMIC);
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode Arm82Convolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
const int outputPlaneLen = output->height() * output->width();
|
||||
|
||||
const int dstZStep = outputPlaneLen * ARMV82_CHANNEL_UNIT;
|
||||
const int batch = input->batch();
|
||||
const int ocDiv8 = UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT);
|
||||
const int kernelCountUnit = mIm2ColParamter.kernelCountUnit;
|
||||
|
||||
const auto inputDataPtr = input->host<FLOAT16>();
|
||||
const auto weightDataPtr = mWeightFp16->host<FLOAT16>();
|
||||
const auto biasDataPtr = mBiasFp16->host<FLOAT16>();
|
||||
auto im2ColPtr = mIm2ColBuffer.host<FLOAT16>();
|
||||
auto outputDataPtr = output->host<FLOAT16>();
|
||||
auto remainDataPtr = mRemainBuffer.host<FLOAT16>();
|
||||
|
||||
auto im2ColProcess = Im2ColTransformer;
|
||||
bool useFastIm2Col = mIm2ColParamter.kernelX == 1 && mIm2ColParamter.kernelY == 1 && mIm2ColParamter.strideX == 1 &&
|
||||
mIm2ColParamter.strideY == 1 && mIm2ColParamter.padX == 0 && mIm2ColParamter.padY == 0;
|
||||
|
||||
if (useFastIm2Col) {
|
||||
im2ColProcess = Im2ColTransformer1x1;
|
||||
}
|
||||
|
||||
const int inBatchStride = ROUND_UP(input->channel(), ARMV82_CHANNEL_UNIT) * input->height() * input->width();
|
||||
const int outBatchStride = ocDiv8 * dstZStep;
|
||||
for (int bIndex = 0; bIndex < batch; ++bIndex) {
|
||||
const auto srcBatchPtr = inputDataPtr + bIndex * inBatchStride;
|
||||
auto dstBatchPtr = outputDataPtr + bIndex * outBatchStride;
|
||||
|
||||
auto threadFunction = [&](int tId) {
|
||||
auto im2ColCurPtr = im2ColPtr + tId * mIm2ColBuffer.stride(0);
|
||||
auto gemmOutputPtr = remainDataPtr + tId * mRemainBuffer.stride(0);
|
||||
|
||||
for (int tIndex = tId; tIndex < mTileCount; tIndex += mThreadNums) {
|
||||
const int xIndexStart = tIndex * DST_XUNIT;
|
||||
const int realDstCount = ALIMIN(outputPlaneLen - xIndexStart, DST_XUNIT);
|
||||
|
||||
Im2ColTransformer(im2ColCurPtr, srcBatchPtr, &mIm2ColParamter, xIndexStart, realDstCount);
|
||||
|
||||
auto outputCurTilePtr = dstBatchPtr + xIndexStart * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
if (realDstCount == DST_XUNIT) {
|
||||
// compute one tile
|
||||
MNNGemmFP16C8_UNIT(outputCurTilePtr, im2ColCurPtr, weightDataPtr, biasDataPtr, kernelCountUnit,
|
||||
dstZStep * sizeof(FLOAT16), ocDiv8, mRelu, mRelu6, realDstCount);
|
||||
} else {
|
||||
// compute the remain
|
||||
MNNGemmFP16C8_UNIT(gemmOutputPtr, im2ColCurPtr, weightDataPtr, biasDataPtr, kernelCountUnit,
|
||||
ARMV82_CHANNEL_UNIT * DST_XUNIT * sizeof(FLOAT16), ocDiv8, mRelu, mRelu6,
|
||||
realDstCount);
|
||||
for (int z = 0; z < ocDiv8; ++z) {
|
||||
auto outputz = outputCurTilePtr + z * dstZStep;
|
||||
auto srcz = gemmOutputPtr + z * ARMV82_CHANNEL_UNIT * DST_XUNIT;
|
||||
memcpy(outputz, srcz, realDstCount * ARMV82_CHANNEL_UNIT * sizeof(FLOAT16));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
MNN_CONCURRENCY_BEGIN(tId, mThreadNums)
|
||||
threadFunction((int)tId);
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
MNN_CONCURRENCY_END();
|
||||
#else
|
||||
MNN_CONCURRENCY_END();
|
||||
#endif
|
||||
}
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
class Arm82ConvolutionCreator : public Arm82Backend::Arm82Creator {
|
||||
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||
const MNN::Op *op, Backend *backend) const override {
|
||||
auto convParam = op->main_as_Convolution2D();
|
||||
// avoid other quantize method entry this creator
|
||||
if(convParam->quanParameter() && convParam->quanParameter()->type() != 3){
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
const auto param = convParam->common();
|
||||
if (param->kernelX() == 3 && param->kernelY() == 3 && param->strideX() == 1 && param->strideY() == 1 &&
|
||||
param->dilateX() == 1 && param->dilateY() == 1) {
|
||||
return new Arm82Convolution3x3(convParam, backend);
|
||||
}
|
||||
#endif
|
||||
return new Arm82Convolution(convParam, backend);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_ARM82_OP_CREATOR(OpType_Convolution, Arm82ConvolutionCreator);
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif
|
|
@ -1,40 +0,0 @@
|
|||
//
|
||||
// Arm82Convolution.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/01/07.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
#ifndef Arm82Convolution_hpp
|
||||
#define Arm82Convolution_hpp
|
||||
|
||||
#include "core/ConvolutionCommon.hpp"
|
||||
#include "core/Execution.hpp"
|
||||
|
||||
namespace MNN {
|
||||
class Arm82Convolution : public Execution {
|
||||
public:
|
||||
Arm82Convolution(const MNN::Convolution2D *convParam, Backend *bn);
|
||||
virtual ~Arm82Convolution();
|
||||
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
|
||||
private:
|
||||
// plane tile number
|
||||
int mTileCount;
|
||||
int mThreadNums;
|
||||
bool mRelu;
|
||||
bool mRelu6;
|
||||
ConvolutionCommon::Im2ColParameter mIm2ColParamter;
|
||||
std::shared_ptr<Tensor> mWeightFp16;
|
||||
std::shared_ptr<Tensor> mBiasFp16;
|
||||
|
||||
Tensor mIm2ColBuffer;
|
||||
Tensor mRemainBuffer;
|
||||
const Convolution2DCommon *mCommon;
|
||||
};
|
||||
} // namespace MNN
|
||||
|
||||
#endif /* Arm82Convolution_hpp */
|
||||
#endif
|
|
@ -1,537 +0,0 @@
|
|||
//
|
||||
// Arm82Convolution3x3.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/02/04.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
#include "backend/arm82/Arm82Convolution3x3.hpp"
|
||||
#include "backend/arm82/Arm82OptFunc.hpp"
|
||||
#include "core/Concurrency.h"
|
||||
#include "core/Macro.h"
|
||||
#include "core/TensorUtils.hpp"
|
||||
#include "core/ConvolutionCommon.hpp"
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
constexpr int CONV3X3_WINO_OUT = 4;
|
||||
constexpr int CONV3X3_WINO_KER = 3;
|
||||
constexpr int CONV3X3_WINO_IN = CONV3X3_WINO_OUT + CONV3X3_WINO_KER - 1;
|
||||
constexpr int CONV3X3_WEIGHT_UNIT = CONV3X3_WINO_IN * CONV3X3_WINO_IN * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
constexpr int CONV3X3_WINO_TILE = 8;
|
||||
constexpr int CONV3X3_WINO_SRC_NUM = CONV3X3_WINO_IN * CONV3X3_WINO_IN * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
namespace MNN {
|
||||
|
||||
// winograd F(4,3)
|
||||
#ifdef MNN_USE_NEON
|
||||
static void kernelTransform_wino_4x4_3x3(const FLOAT16* src, FLOAT16* dst, int step) {
|
||||
FLOAT16 midResult6X3[6][3];
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_KER; ++i) {
|
||||
FLOAT16 a0i = src[i];
|
||||
FLOAT16 a1i = src[1 * CONV3X3_WINO_KER + i];
|
||||
FLOAT16 a2i = src[2 * CONV3X3_WINO_KER + i];
|
||||
|
||||
midResult6X3[0][i] = 0.25f * a0i;
|
||||
midResult6X3[1][i] = (a0i + a1i + a2i) * -0.1666666666666667f;
|
||||
midResult6X3[2][i] = (a0i - a1i + a2i) * -0.1666666666666667f;
|
||||
midResult6X3[3][i] = a0i * 0.04166667f + a1i * 0.08333333f + a2i * 0.1666666666666667f;
|
||||
midResult6X3[4][i] = a0i * 0.04166667f - a1i * 0.08333333f + a2i * 0.1666666666666667f;
|
||||
midResult6X3[5][i] = a2i;
|
||||
}
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_IN; ++i) {
|
||||
auto curRowDst = dst;
|
||||
curRowDst[0 * step] = 0.25f * midResult6X3[i][0];
|
||||
curRowDst[1 * step] = (midResult6X3[i][0] + midResult6X3[i][1] + midResult6X3[i][2]) * -0.1666666666666667f;
|
||||
curRowDst[2 * step] = (midResult6X3[i][0] - midResult6X3[i][1] + midResult6X3[i][2]) * -0.1666666666666667f;
|
||||
curRowDst[3 * step] = midResult6X3[i][0] * 0.04166667f + midResult6X3[i][1] * 0.08333333f +
|
||||
midResult6X3[i][2] * 0.1666666666666667f;
|
||||
curRowDst[4 * step] = midResult6X3[i][0] * 0.04166667f - midResult6X3[i][1] * 0.08333333f +
|
||||
midResult6X3[i][2] * 0.1666666666666667f;
|
||||
curRowDst[5 * step] = midResult6X3[i][2];
|
||||
dst += CONV3X3_WINO_IN * step;
|
||||
}
|
||||
}
|
||||
|
||||
static void sourceTransform_wino_4x4_3x3(const FLOAT16* src, FLOAT16* dst, int step) {
|
||||
FLOAT16 midResult[6][6][ARMV82_CHANNEL_UNIT];
|
||||
|
||||
float16x8_t value_4 = vmovq_n_f16(4);
|
||||
float16x8_t value_neg_5 = vmovq_n_f16(-5);
|
||||
float16x8_t value_neg_4 = vmovq_n_f16(-4);
|
||||
float16x8_t value_2 = vmovq_n_f16(2);
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_IN; ++i) {
|
||||
float16x8_t a0i = vld1q_f16(src + (0 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
float16x8_t a1i = vld1q_f16(src + (1 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
float16x8_t a2i = vld1q_f16(src + (2 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
float16x8_t a3i = vld1q_f16(src + (3 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
float16x8_t a4i = vld1q_f16(src + (4 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
float16x8_t a5i = vld1q_f16(src + (5 * CONV3X3_WINO_IN + i) * ARMV82_CHANNEL_UNIT);
|
||||
|
||||
float16x8_t b0 = vfmaq_f16(a4i, a2i, value_neg_4);
|
||||
float16x8_t b1 = vfmaq_f16(a3i, a1i, value_neg_4);
|
||||
float16x8_t b2 = vsubq_f16(a4i, a2i);
|
||||
float16x8_t b3 = vmulq_f16(vsubq_f16(a3i, a1i), value_2);
|
||||
float16x8_t b4 = vfmaq_f16(a4i, a0i, value_4);
|
||||
float16x8_t b5 = vfmaq_f16(a5i, a1i, value_4);
|
||||
|
||||
float16x8_t r0 = vfmaq_f16(b4, value_neg_5, a2i);
|
||||
float16x8_t r1 = vaddq_f16(b0, b1);
|
||||
float16x8_t r2 = vsubq_f16(b0, b1);
|
||||
float16x8_t r3 = vaddq_f16(b2, b3);
|
||||
float16x8_t r4 = vsubq_f16(b2, b3);
|
||||
float16x8_t r5 = vfmaq_f16(b5, value_neg_5, a3i);
|
||||
|
||||
vst1q_f16(midResult[0][i], r0);
|
||||
vst1q_f16(midResult[1][i], r1);
|
||||
vst1q_f16(midResult[2][i], r2);
|
||||
vst1q_f16(midResult[3][i], r3);
|
||||
vst1q_f16(midResult[4][i], r4);
|
||||
vst1q_f16(midResult[5][i], r5);
|
||||
}
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_IN; ++i) {
|
||||
float16x8_t a0i = vld1q_f16(midResult[i][0]);
|
||||
float16x8_t a1i = vld1q_f16(midResult[i][1]);
|
||||
float16x8_t a2i = vld1q_f16(midResult[i][2]);
|
||||
float16x8_t a3i = vld1q_f16(midResult[i][3]);
|
||||
float16x8_t a4i = vld1q_f16(midResult[i][4]);
|
||||
float16x8_t a5i = vld1q_f16(midResult[i][5]);
|
||||
|
||||
float16x8_t b0 = vfmaq_f16(a4i, a2i, value_neg_4);
|
||||
float16x8_t b1 = vfmaq_f16(a3i, a1i, value_neg_4);
|
||||
float16x8_t b2 = vsubq_f16(a4i, a2i);
|
||||
float16x8_t b3 = vmulq_f16(vsubq_f16(a3i, a1i), value_2);
|
||||
float16x8_t b4 = vfmaq_f16(a4i, a0i, value_4);
|
||||
float16x8_t b5 = vfmaq_f16(a5i, a1i, value_4);
|
||||
|
||||
float16x8_t r0 = vfmaq_f16(b4, value_neg_5, a2i);
|
||||
float16x8_t r1 = vaddq_f16(b0, b1);
|
||||
float16x8_t r2 = vsubq_f16(b0, b1);
|
||||
float16x8_t r3 = vaddq_f16(b2, b3);
|
||||
float16x8_t r4 = vsubq_f16(b2, b3);
|
||||
float16x8_t r5 = vfmaq_f16(b5, value_neg_5, a3i);
|
||||
|
||||
vst1q_f16(dst + 0 * step, r0);
|
||||
vst1q_f16(dst + 1 * step, r1);
|
||||
vst1q_f16(dst + 2 * step, r2);
|
||||
vst1q_f16(dst + 3 * step, r3);
|
||||
vst1q_f16(dst + 4 * step, r4);
|
||||
vst1q_f16(dst + 5 * step, r5);
|
||||
dst += CONV3X3_WINO_IN * step;
|
||||
}
|
||||
}
|
||||
|
||||
static void dstTransform_wino_4x4_3x3(const FLOAT16* src, const FLOAT16* bias, bool relu, bool relu6, FLOAT16* dst,
|
||||
int step) {
|
||||
FLOAT16 midResult[4][6][ARMV82_CHANNEL_UNIT];
|
||||
|
||||
float16x8_t value_0 = vmovq_n_f16(0);
|
||||
float16x8_t value_6 = vmovq_n_f16(6);
|
||||
float16x8_t value_2 = vmovq_n_f16(2);
|
||||
float16x8_t value_4 = vmovq_n_f16(4);
|
||||
float16x8_t value_8 = vmovq_n_f16(8);
|
||||
|
||||
float16x8_t value_bias = vld1q_f16(bias);
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_IN; ++i) {
|
||||
float16x8_t a0i = vld1q_f16(src + (CONV3X3_WINO_IN * 0 + i) * step);
|
||||
float16x8_t a1i = vld1q_f16(src + (CONV3X3_WINO_IN * 1 + i) * step);
|
||||
float16x8_t a2i = vld1q_f16(src + (CONV3X3_WINO_IN * 2 + i) * step);
|
||||
float16x8_t a3i = vld1q_f16(src + (CONV3X3_WINO_IN * 3 + i) * step);
|
||||
float16x8_t a4i = vld1q_f16(src + (CONV3X3_WINO_IN * 4 + i) * step);
|
||||
float16x8_t a5i = vld1q_f16(src + (CONV3X3_WINO_IN * 5 + i) * step);
|
||||
|
||||
float16x8_t b0 = vaddq_f16(a1i, a2i);
|
||||
float16x8_t b1 = vaddq_f16(a3i, a4i);
|
||||
float16x8_t b2 = vsubq_f16(a1i, a2i);
|
||||
float16x8_t b3 = vsubq_f16(a3i, a4i);
|
||||
|
||||
float16x8_t r0 = vaddq_f16(vaddq_f16(b0, b1), a0i);
|
||||
float16x8_t r1 = vfmaq_f16(b2, b3, value_2);
|
||||
float16x8_t r2 = vfmaq_f16(b0, b1, value_4);
|
||||
float16x8_t r3 = vaddq_f16(a5i, vfmaq_f16(b2, b3, value_8));
|
||||
|
||||
vst1q_f16(midResult[0][i], r0);
|
||||
vst1q_f16(midResult[1][i], r1);
|
||||
vst1q_f16(midResult[2][i], r2);
|
||||
vst1q_f16(midResult[3][i], r3);
|
||||
}
|
||||
|
||||
for (int i = 0; i < CONV3X3_WINO_OUT; ++i) {
|
||||
float16x8_t a0i = vld1q_f16(midResult[i][0]);
|
||||
float16x8_t a1i = vld1q_f16(midResult[i][1]);
|
||||
float16x8_t a2i = vld1q_f16(midResult[i][2]);
|
||||
float16x8_t a3i = vld1q_f16(midResult[i][3]);
|
||||
float16x8_t a4i = vld1q_f16(midResult[i][4]);
|
||||
float16x8_t a5i = vld1q_f16(midResult[i][5]);
|
||||
|
||||
float16x8_t b0 = vaddq_f16(a1i, a2i);
|
||||
float16x8_t b1 = vaddq_f16(a3i, a4i);
|
||||
float16x8_t b2 = vsubq_f16(a1i, a2i);
|
||||
float16x8_t b3 = vsubq_f16(a3i, a4i);
|
||||
|
||||
float16x8_t r0 = vaddq_f16(vaddq_f16(b0, b1), a0i);
|
||||
float16x8_t r1 = vfmaq_f16(b2, b3, value_2);
|
||||
float16x8_t r2 = vfmaq_f16(b0, b1, value_4);
|
||||
float16x8_t r3 = vaddq_f16(a5i, vfmaq_f16(b2, b3, value_8));
|
||||
|
||||
r0 = vaddq_f16(r0, value_bias);
|
||||
r1 = vaddq_f16(r1, value_bias);
|
||||
r2 = vaddq_f16(r2, value_bias);
|
||||
r3 = vaddq_f16(r3, value_bias);
|
||||
|
||||
if (relu) {
|
||||
r0 = vmaxq_f16(r0, value_0);
|
||||
r1 = vmaxq_f16(r1, value_0);
|
||||
r2 = vmaxq_f16(r2, value_0);
|
||||
r3 = vmaxq_f16(r3, value_0);
|
||||
}
|
||||
if (relu6) {
|
||||
r0 = vmaxq_f16(r0, value_0);
|
||||
r1 = vmaxq_f16(r1, value_0);
|
||||
r2 = vmaxq_f16(r2, value_0);
|
||||
r3 = vmaxq_f16(r3, value_0);
|
||||
r0 = vminq_f16(r0, value_6);
|
||||
r1 = vminq_f16(r1, value_6);
|
||||
r2 = vminq_f16(r2, value_6);
|
||||
r3 = vminq_f16(r3, value_6);
|
||||
}
|
||||
|
||||
vst1q_f16(dst + 0 * ARMV82_CHANNEL_UNIT, r0);
|
||||
vst1q_f16(dst + 1 * ARMV82_CHANNEL_UNIT, r1);
|
||||
vst1q_f16(dst + 2 * ARMV82_CHANNEL_UNIT, r2);
|
||||
vst1q_f16(dst + 3 * ARMV82_CHANNEL_UNIT, r3);
|
||||
dst += CONV3X3_WINO_OUT * ARMV82_CHANNEL_UNIT;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Arm82Convolution3x3::Arm82Convolution3x3(const MNN::Convolution2D* convParam, Backend* bn) : Execution(bn) {
|
||||
const auto commonParam = convParam->common();
|
||||
mCommon = commonParam;
|
||||
int inputChannel = commonParam->inputCount();
|
||||
const int outputChannel = commonParam->outputCount();
|
||||
|
||||
if (inputChannel == 0) {
|
||||
if (convParam->quanParameter()) {
|
||||
inputChannel = convParam->quanParameter()->buffer()->size() / (2 * 9 * outputChannel);
|
||||
} else {
|
||||
inputChannel = convParam->weight()->size() / (9 * outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
const int icDiv8 = UP_DIV(inputChannel, ARMV82_CHANNEL_UNIT);
|
||||
const int ocDiv8 = UP_DIV(outputChannel, ARMV82_CHANNEL_UNIT);
|
||||
mRelu = mCommon->relu();
|
||||
mRelu6 = mCommon->relu6();
|
||||
// transform weight
|
||||
{
|
||||
mWeightFp16.reset(
|
||||
Tensor::createDevice<uint16_t>({icDiv8 * ocDiv8 * CONV3X3_WEIGHT_UNIT * ARMV82_CHANNEL_UNIT}));
|
||||
mValid = bn->onAcquireBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
if (!mValid) {
|
||||
return;
|
||||
}
|
||||
|
||||
memset(mWeightFp16->host<uint16_t>(), 0, mWeightFp16->size());
|
||||
|
||||
// Set source size align avoid of heap error
|
||||
std::vector<FLOAT16> weightFp16(ocDiv8 * ARMV82_CHANNEL_UNIT * inputChannel * CONV3X3_WINO_KER * CONV3X3_WINO_KER, 0);
|
||||
const FLOAT16* fp16WeightPtr = weightFp16.data();
|
||||
if (convParam->quanParameter()) {
|
||||
MNN_ASSERT((convParam->quanParameter()->type() == 3) || (convParam->quanParameter()->type() == 4));
|
||||
if (convParam->quanParameter()->type() == 3) {
|
||||
// the data type of weight is fp16
|
||||
::memcpy(weightFp16.data(), convParam->quanParameter()->buffer()->data(), convParam->quanParameter()->buffer()->size());
|
||||
}
|
||||
if (convParam->quanParameter()->type() == 4) {
|
||||
std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
|
||||
quanCommon = ConvolutionCommon::load(convParam->quanParameter(), true);
|
||||
int weightCount = convParam->quanParameter()->buffer()->size();
|
||||
MNNQuantizeFP16(weightFp16.data(), quanCommon->weightFloat.get(), weightCount);
|
||||
}
|
||||
} else {
|
||||
// the data type of weight is fp32, then quantize weight to be fp16 data type
|
||||
int size = convParam->weight()->size();
|
||||
MNNQuantizeFP16(weightFp16.data(), convParam->weight()->data(), size);
|
||||
}
|
||||
|
||||
const auto srcWeightPtr = fp16WeightPtr;
|
||||
auto dstWeightPtr = mWeightFp16->host<FLOAT16>();
|
||||
|
||||
auto transformWeight = [&](int ocUnit, int ocStart, int ocEnd, FLOAT16* weight) {
|
||||
for (int oc = ocStart; oc < ocEnd; ++oc) {
|
||||
const int oci = oc / ocUnit;
|
||||
const int ocj = oc % ocUnit;
|
||||
const auto srcWeightOcPtr = srcWeightPtr + oc * inputChannel * CONV3X3_WINO_KER * CONV3X3_WINO_KER;
|
||||
auto dstWeightOcPtr = weight + oci * icDiv8 * ARMV82_CHANNEL_UNIT * ocUnit + ocj;
|
||||
for (int ic = 0; ic < inputChannel; ++ic) {
|
||||
const auto srcWeightIcPtr = srcWeightOcPtr + ic * CONV3X3_WINO_KER * CONV3X3_WINO_KER;
|
||||
auto dstWeightIcPtr = dstWeightOcPtr + ic * ocUnit;
|
||||
|
||||
kernelTransform_wino_4x4_3x3(srcWeightIcPtr, dstWeightIcPtr,
|
||||
icDiv8 * ocDiv8 * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const int ocDivDoubleUnit = ocDiv8 / 2;
|
||||
if (ocDivDoubleUnit > 0) {
|
||||
transformWeight((ARMV82_CHANNEL_UNIT * 2), 0, ocDivDoubleUnit * (ARMV82_CHANNEL_UNIT * 2), dstWeightPtr);
|
||||
}
|
||||
if (ocDiv8 % 2 == 1) {
|
||||
transformWeight(ARMV82_CHANNEL_UNIT, ocDivDoubleUnit * (ARMV82_CHANNEL_UNIT * 2), outputChannel,
|
||||
dstWeightPtr);
|
||||
}
|
||||
}
|
||||
|
||||
mBiasFp16.reset(Tensor::createDevice<uint16_t>({ocDiv8 * ARMV82_CHANNEL_UNIT}));
|
||||
mValid = bn->onAcquireBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
if (!mValid) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO, bias is fp32, save bias also in fp16?
|
||||
auto biasDstPtr = mBiasFp16->host<FLOAT16>();
|
||||
memset(biasDstPtr, 0, mBiasFp16->size());
|
||||
MNNQuantizeFP16(biasDstPtr, convParam->bias()->data(), outputChannel);
|
||||
}
|
||||
|
||||
Arm82Convolution3x3::~Arm82Convolution3x3() {
|
||||
if (nullptr != mWeightFp16) {
|
||||
backend()->onReleaseBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
}
|
||||
if (nullptr != mBiasFp16) {
|
||||
backend()->onReleaseBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode Arm82Convolution3x3::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
|
||||
mPadX = mCommon->padX();
|
||||
mPadY = mCommon->padY();
|
||||
if (mCommon->padMode() == PadMode_SAME) {
|
||||
int kernelWidthSize = (mCommon->kernelX() - 1) * mCommon->dilateX() + 1;
|
||||
int kernelHeightSize = (mCommon->kernelY() - 1) * mCommon->dilateY() + 1;
|
||||
|
||||
int padNeededWidth = (output->width() - 1) * mCommon->strideX() + kernelWidthSize - input->width();
|
||||
int padNeededHeight = (output->height() - 1) * mCommon->strideY() + kernelHeightSize - input->height();
|
||||
mPadX = padNeededWidth / 2;
|
||||
mPadY = padNeededHeight / 2;
|
||||
}
|
||||
|
||||
mThreadNums = std::max(static_cast<Arm82Backend*>(backend())->numberThread(), 1);
|
||||
mTransformBuffer.buffer().dimensions = 4;
|
||||
mTransformBuffer.setType(DataType_DT_BFLOAT16);
|
||||
mTransformBuffer.setLength(0, mThreadNums);
|
||||
mTransformBuffer.setLength(1, CONV3X3_WINO_TILE);
|
||||
mTransformBuffer.setLength(
|
||||
2, UP_DIV(input->channel(), ARMV82_CHANNEL_UNIT) + UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT) + 1);
|
||||
mTransformBuffer.setLength(3, CONV3X3_WINO_SRC_NUM);
|
||||
TensorUtils::setLinearLayout(&mTransformBuffer);
|
||||
|
||||
bool allocSuccess = backend()->onAcquireBuffer(&mTransformBuffer, Backend::DYNAMIC);
|
||||
if (!allocSuccess) {
|
||||
return OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
mDummyBias.buffer().dimensions = 1;
|
||||
mDummyBias.setType(DataType_DT_BFLOAT16);
|
||||
mDummyBias.setLength(0, UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT) * ARMV82_CHANNEL_UNIT);
|
||||
allocSuccess = backend()->onAcquireBuffer(&mDummyBias, Backend::DYNAMIC);
|
||||
if (!allocSuccess) {
|
||||
return OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
backend()->onReleaseBuffer(&mTransformBuffer, Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(&mDummyBias, Backend::DYNAMIC);
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode Arm82Convolution3x3::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
const int batch = input->batch();
|
||||
const int ih = input->height();
|
||||
const int iw = input->width();
|
||||
const int ihw = ih * iw;
|
||||
const int icDiv8 = UP_DIV(input->channel(), ARMV82_CHANNEL_UNIT);
|
||||
const int oh = output->height();
|
||||
const int ow = output->width();
|
||||
const int ohw = oh * ow;
|
||||
const int ocDiv8 = UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT);
|
||||
|
||||
const int hUnit = UP_DIV(oh, CONV3X3_WINO_OUT);
|
||||
const int wUnit = UP_DIV(ow, CONV3X3_WINO_OUT);
|
||||
|
||||
const int hPadded = hUnit * CONV3X3_WINO_OUT - oh;
|
||||
const int wPadded = wUnit * CONV3X3_WINO_OUT - ow;
|
||||
|
||||
const int outUnitCount = hUnit * wUnit;
|
||||
const int tileCount = UP_DIV(outUnitCount, CONV3X3_WINO_TILE);
|
||||
|
||||
const auto weightPtr = mWeightFp16->host<FLOAT16>();
|
||||
const auto biasDummyPtr = mDummyBias.host<FLOAT16>();
|
||||
const auto biasPtr = mBiasFp16->host<FLOAT16>();
|
||||
|
||||
memset(mDummyBias.host<FLOAT16>(), 0, mDummyBias.size());
|
||||
|
||||
auto srcGetAndTransformFunc = [=](int xIndex, int realTile, const FLOAT16* srcOrigin, FLOAT16* transformedBuffer,
|
||||
FLOAT16* tempBuffer) {
|
||||
memset(tempBuffer, 0, CONV3X3_WINO_TILE * CONV3X3_WINO_SRC_NUM * sizeof(FLOAT16));
|
||||
for (int tindex = 0; tindex < realTile; ++tindex) {
|
||||
int index = xIndex + tindex;
|
||||
int hindex = index / wUnit;
|
||||
int windex = index % wUnit;
|
||||
|
||||
int srcX = windex * CONV3X3_WINO_OUT - mPadX;
|
||||
int srcY = hindex * CONV3X3_WINO_OUT - mPadY;
|
||||
int sy = ALIMAX(0, srcY) - srcY;
|
||||
int ey = ALIMIN(srcY + CONV3X3_WINO_IN, ih) - srcY;
|
||||
int sx = ALIMAX(0, srcX) - srcX;
|
||||
int ex = ALIMIN(srcX + CONV3X3_WINO_IN, iw) - srcX;
|
||||
|
||||
const auto srcStart = srcOrigin + (srcX + srcY * iw) * ARMV82_CHANNEL_UNIT;
|
||||
auto curTransPtr = transformedBuffer + tindex * ARMV82_CHANNEL_UNIT;
|
||||
auto curTempBuffer = tempBuffer + tindex * CONV3X3_WINO_SRC_NUM;
|
||||
|
||||
for (int c = 0; c < icDiv8; ++c) {
|
||||
const auto curChannelSrcPtr = srcStart + c * ihw * ARMV82_CHANNEL_UNIT;
|
||||
auto curChannelTransPtr = curTransPtr + c * CONV3X3_WINO_TILE * ARMV82_CHANNEL_UNIT;
|
||||
if (ex > sx) {
|
||||
for (int yy = sy; yy < ey; ++yy) {
|
||||
const auto srcPtr = curChannelSrcPtr + yy * iw * ARMV82_CHANNEL_UNIT;
|
||||
auto dstPtr = curTempBuffer + yy * CONV3X3_WINO_IN * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
memcpy(dstPtr + ARMV82_CHANNEL_UNIT * sx, srcPtr + ARMV82_CHANNEL_UNIT * sx,
|
||||
(ex - sx) * sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
}
|
||||
}
|
||||
|
||||
sourceTransform_wino_4x4_3x3(curTempBuffer, curChannelTransPtr,
|
||||
ARMV82_CHANNEL_UNIT * CONV3X3_WINO_TILE * icDiv8);
|
||||
}
|
||||
}
|
||||
|
||||
// shuffel channel
|
||||
if (realTile > (CONV3X3_WINO_TILE / 2)) {
|
||||
MNNShuffleChannelC8(transformedBuffer, transformedBuffer,
|
||||
(size_t)(icDiv8 * CONV3X3_WINO_IN * CONV3X3_WINO_IN), 0);
|
||||
} else {
|
||||
for (int i = 0; i < CONV3X3_WINO_IN * CONV3X3_WINO_IN; ++i) {
|
||||
auto dst = transformedBuffer + i * ARMV82_CHANNEL_UNIT * CONV3X3_WINO_TILE * icDiv8;
|
||||
MNNShuffleChannelC8(dst, dst, (size_t)(icDiv8), 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto dstTransformAndSave = [=](int xIndex, int realTile, const FLOAT16* transformedBuffer, const FLOAT16* bias,
|
||||
bool relu, bool relu6, FLOAT16* dstOrigin, FLOAT16* tempBuffer) {
|
||||
for (int tindex = 0; tindex < realTile; ++tindex) {
|
||||
int index = xIndex + tindex;
|
||||
int hindex = index / wUnit;
|
||||
int windex = index % wUnit;
|
||||
int dstX = windex * CONV3X3_WINO_OUT;
|
||||
int dstY = hindex * CONV3X3_WINO_OUT;
|
||||
|
||||
const auto curTransPtr = transformedBuffer + tindex * ARMV82_CHANNEL_UNIT;
|
||||
auto dstStartPtr = dstOrigin + (dstX + dstY * ow) * ARMV82_CHANNEL_UNIT;
|
||||
auto curTempBuffer = tempBuffer + tindex * CONV3X3_WINO_SRC_NUM;
|
||||
|
||||
int hReamin = CONV3X3_WINO_OUT;
|
||||
int wReamin = CONV3X3_WINO_OUT;
|
||||
|
||||
if (hindex == (hUnit - 1)) {
|
||||
hReamin = CONV3X3_WINO_OUT - hPadded;
|
||||
}
|
||||
if (windex == (wUnit - 1)) {
|
||||
wReamin = CONV3X3_WINO_OUT - wPadded;
|
||||
}
|
||||
|
||||
for (int z = 0; z < ocDiv8; ++z) {
|
||||
const auto curChannelTransPtr = curTransPtr + z * CONV3X3_WINO_TILE * ARMV82_CHANNEL_UNIT;
|
||||
auto dstZ = dstStartPtr + z * ohw * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
dstTransform_wino_4x4_3x3(curChannelTransPtr, bias + z * ARMV82_CHANNEL_UNIT, relu, relu6,
|
||||
curTempBuffer, ocDiv8 * CONV3X3_WINO_TILE * ARMV82_CHANNEL_UNIT);
|
||||
|
||||
// save 4x4 outputs from tempBuffer
|
||||
for (int i = 0; i < hReamin; ++i) {
|
||||
memcpy(dstZ + i * ow * ARMV82_CHANNEL_UNIT,
|
||||
curTempBuffer + i * CONV3X3_WINO_OUT * ARMV82_CHANNEL_UNIT,
|
||||
sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT * wReamin);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto threadFunction = [&](size_t tId, size_t tileStart, int tileStep, int tileEnd, const FLOAT16* srcOrigin,
|
||||
FLOAT16* dstOrigin) {
|
||||
auto curThreadTransformPtr = mTransformBuffer.host<FLOAT16>() + tId * mTransformBuffer.stride(0);
|
||||
auto srcTransformedPtr = curThreadTransformPtr;
|
||||
auto dstTransformedPtr = curThreadTransformPtr + CONV3X3_WINO_TILE * CONV3X3_WINO_SRC_NUM * icDiv8;
|
||||
auto tempBufferPtr = curThreadTransformPtr + CONV3X3_WINO_TILE * CONV3X3_WINO_SRC_NUM * (icDiv8 + ocDiv8);
|
||||
|
||||
for (size_t tindex = tileStart; tindex < tileEnd; tindex += tileStep) {
|
||||
int xIndex = (int)tindex * CONV3X3_WINO_TILE;
|
||||
int xRemain = outUnitCount - xIndex;
|
||||
int realTileNum = xRemain > CONV3X3_WINO_TILE ? CONV3X3_WINO_TILE : xRemain;
|
||||
|
||||
srcGetAndTransformFunc(xIndex, realTileNum, srcOrigin, srcTransformedPtr, tempBufferPtr);
|
||||
|
||||
// matmul
|
||||
for (int i = 0; i < CONV3X3_WINO_IN * CONV3X3_WINO_IN; ++i) {
|
||||
MNNGemmFP16C8_UNIT(dstTransformedPtr + i * ocDiv8 * CONV3X3_WINO_TILE * ARMV82_CHANNEL_UNIT,
|
||||
srcTransformedPtr + i * ARMV82_CHANNEL_UNIT * CONV3X3_WINO_TILE * icDiv8,
|
||||
weightPtr + i * icDiv8 * ocDiv8 * ARMV82_CHANNEL_UNIT * ARMV82_CHANNEL_UNIT,
|
||||
biasDummyPtr, icDiv8, ARMV82_CHANNEL_UNIT * CONV3X3_WINO_TILE * sizeof(FLOAT16),
|
||||
ocDiv8, 0, 0, realTileNum);
|
||||
}
|
||||
|
||||
dstTransformAndSave(xIndex, realTileNum, dstTransformedPtr, biasPtr, mRelu, mRelu6, dstOrigin,
|
||||
tempBufferPtr);
|
||||
}
|
||||
};
|
||||
|
||||
const auto srcOriginPtr = input->host<FLOAT16>();
|
||||
auto dstOriginPtr = output->host<FLOAT16>();
|
||||
const int inBatchStride = icDiv8 * ihw * ARMV82_CHANNEL_UNIT;
|
||||
const int outBatchStride = ocDiv8 * ohw * ARMV82_CHANNEL_UNIT;
|
||||
for (int bIndex = 0; bIndex < batch; ++bIndex) {
|
||||
const auto curSrcBatchPtr = srcOriginPtr + bIndex * inBatchStride;
|
||||
auto curDstBatchPtr = dstOriginPtr + bIndex * outBatchStride;
|
||||
|
||||
if (tileCount >= mThreadNums) {
|
||||
MNN_CONCURRENCY_BEGIN(tId, mThreadNums)
|
||||
threadFunction((int)tId, (int)tId, mThreadNums, (tileCount / mThreadNums) * mThreadNums, curSrcBatchPtr,
|
||||
curDstBatchPtr);
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
MNN_CONCURRENCY_END();
|
||||
#else
|
||||
MNN_CONCURRENCY_END();
|
||||
#endif
|
||||
}
|
||||
if (tileCount % mThreadNums != 0) {
|
||||
threadFunction(0, (tileCount / mThreadNums) * mThreadNums, 1, tileCount, curSrcBatchPtr, curDstBatchPtr);
|
||||
}
|
||||
}
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif
|
|
@ -1,43 +0,0 @@
|
|||
//
|
||||
// Arm82Convolution3x3.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/02/04.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
|
||||
#ifndef Arm82Convolution3x3_hpp
|
||||
#define Arm82Convolution3x3_hpp
|
||||
|
||||
#include "backend/arm82/Arm82Backend.hpp"
|
||||
#include "core/Execution.hpp"
|
||||
|
||||
namespace MNN {
|
||||
class Arm82Convolution3x3 : public Execution {
|
||||
public:
|
||||
Arm82Convolution3x3(const MNN::Convolution2D *convParam, Backend *bn);
|
||||
virtual ~Arm82Convolution3x3();
|
||||
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
|
||||
private:
|
||||
int mTileCount;
|
||||
int mThreadNums;
|
||||
int mPadX;
|
||||
int mPadY;
|
||||
bool mRelu;
|
||||
bool mRelu6;
|
||||
std::shared_ptr<Tensor> mWeightFp16;
|
||||
std::shared_ptr<Tensor> mBiasFp16;
|
||||
|
||||
Tensor mTransformBuffer;
|
||||
Tensor mDummyBias;
|
||||
const Convolution2DCommon *mCommon;
|
||||
};
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -1,362 +0,0 @@
|
|||
//
|
||||
// Arm82ConvolutionDepthwise.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/01/07.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
#include "backend/arm82/Arm82ConvolutionDepthwise.hpp"
|
||||
#include "core/Concurrency.h"
|
||||
#include "core/Macro.h"
|
||||
#include "backend/arm82/Arm82OptFunc.hpp"
|
||||
#include "core/ConvolutionCommon.hpp"
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
void MNNLineDepthWiseFp16C8Unit(FLOAT16* dst, const FLOAT16* src, const FLOAT16* weight, const FLOAT16* bias_z,
|
||||
size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step,
|
||||
size_t dilateY_step, size_t relu, size_t relu6);
|
||||
}
|
||||
|
||||
namespace MNN {
|
||||
|
||||
static void MNNDepthWiseFp16C8Unit(FLOAT16* dst, const FLOAT16* src, const FLOAT16* weight, const FLOAT16* bias,
|
||||
size_t fw, size_t fh, size_t weight_y_step, size_t dilateX_step, size_t dilateY_step,
|
||||
size_t relu, size_t relu6) {
|
||||
int fx, fy;
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
float16x8_t acc_value = vld1q_f16(bias);
|
||||
#else
|
||||
FLOAT16 acc_value[ARMV82_CHANNEL_UNIT];
|
||||
memcpy(acc_value, bias, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
#endif
|
||||
|
||||
for (fy = 0; fy < fh; ++fy) {
|
||||
const auto src_y = src + fy * dilateY_step;
|
||||
const auto weight_y = weight + fy * weight_y_step;
|
||||
for (fx = 0; fx < fw; ++fx) {
|
||||
const auto weight_x = weight_y + fx * ARMV82_CHANNEL_UNIT;
|
||||
const auto src_x = src_y + fx * dilateX_step;
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
float16x8_t src_x_value = vld1q_f16(src_x);
|
||||
float16x8_t weight_x_value = vld1q_f16(weight_x);
|
||||
acc_value = vfmaq_f16(acc_value, src_x_value, weight_x_value);
|
||||
#else
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
acc_value[j] += src_x[j] * weight_x[j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
if (relu) {
|
||||
float16x8_t zero_value = vdupq_n_f16(float16_t(0.0));
|
||||
acc_value = vmaxq_f16(acc_value, zero_value);
|
||||
}
|
||||
if (relu6) {
|
||||
float16x8_t zero_value = vdupq_n_f16(float16_t(0.0));
|
||||
float16x8_t six_value = vdupq_n_f16(float16_t(6.0));
|
||||
acc_value = vmaxq_f16(acc_value, zero_value);
|
||||
acc_value = vminq_f16(acc_value, six_value);
|
||||
}
|
||||
vst1q_f16(dst, acc_value);
|
||||
#else
|
||||
if (relu) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (acc_value[j] < 0) {
|
||||
acc_value[j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (relu6) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (acc_value[j] < 0) {
|
||||
acc_value[j] = 0;
|
||||
}
|
||||
if (acc_value[j] > 6) {
|
||||
acc_value[j] = 6.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
memcpy(dst, acc_value, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef MNN_USE_NEON
|
||||
static void MNNLineDepthWiseFp16C8Unit(FLOAT16* dst, const FLOAT16* src, const FLOAT16* weight, const FLOAT16* bias_z,
|
||||
size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step,
|
||||
size_t dilateY_step, size_t relu, size_t relu6) {
|
||||
int dx, fx, fy;
|
||||
for (dx = 0; dx < width; ++dx) {
|
||||
auto dst_x = dst + dx * ARMV82_CHANNEL_UNIT;
|
||||
FLOAT16 dst_temp[ARMV82_CHANNEL_UNIT];
|
||||
memcpy(dst_temp, bias_z, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
|
||||
const auto src_z = src + src_w_step * dx;
|
||||
|
||||
for (fy = 0; fy < fh; ++fy) {
|
||||
const auto src_y = src_z + fy * dilateY_step;
|
||||
const auto weight_y = weight + fy * fw * ARMV82_CHANNEL_UNIT;
|
||||
for (fx = 0; fx < fw; ++fx) {
|
||||
const auto src_x = src_y + fx * dilateX_step;
|
||||
const auto weight_x = weight_y + fx * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
dst_temp[j] += src_x[j] * weight_x[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (relu) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (dst_temp[j] < 0) {
|
||||
dst_temp[j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (relu6) {
|
||||
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
||||
if (dst_temp[j] < 0) {
|
||||
dst_temp[j] = 0;
|
||||
}
|
||||
if (dst_temp[j] > 6) {
|
||||
dst_temp[j] = 6.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(dst_x, dst_temp, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Arm82ConvolutionDepthwise::Arm82ConvolutionDepthwise(const MNN::Convolution2D* convParam, Backend* bn) : Execution(bn) {
|
||||
const auto commonParam = convParam->common();
|
||||
mCommon = commonParam;
|
||||
mRelu = commonParam->relu();
|
||||
mRelu6 = commonParam->relu6();
|
||||
const int kx = commonParam->kernelX();
|
||||
const int ky = commonParam->kernelY();
|
||||
const int kernelSize = kx * ky;
|
||||
|
||||
const int outputChannel = commonParam->outputCount();
|
||||
const int ocDivUnit = UP_DIV(outputChannel, ARMV82_CHANNEL_UNIT);
|
||||
const int weightSizeAlignLen = ocDivUnit * ARMV82_CHANNEL_UNIT * kernelSize;
|
||||
mWeightFp16.reset(Tensor::createDevice<uint16_t>({weightSizeAlignLen}));
|
||||
auto success = bn->onAcquireBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
if (!success) {
|
||||
mValid = false;
|
||||
return;
|
||||
}
|
||||
auto weightDstPtr = mWeightFp16->host<FLOAT16>();
|
||||
memset(weightDstPtr, 0, weightSizeAlignLen * sizeof(FLOAT16));
|
||||
|
||||
const FLOAT16* fp16WeightPtr = nullptr;
|
||||
std::vector<FLOAT16> weightFp16;
|
||||
if(convParam->quanParameter()){
|
||||
MNN_ASSERT((convParam->quanParameter()->type() == 3) || (convParam->quanParameter()->type() == 4));
|
||||
if (convParam->quanParameter()->type() == 3) {
|
||||
// the data type of weight is fp16
|
||||
fp16WeightPtr = reinterpret_cast<const FLOAT16 *>(convParam->quanParameter()->buffer()->data());
|
||||
}
|
||||
if (convParam->quanParameter()->type() == 4) {
|
||||
std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
|
||||
quanCommon = ConvolutionCommon::load(convParam->quanParameter(), true);
|
||||
int weightCount = convParam->quanParameter()->buffer()->size();
|
||||
weightFp16.resize(weightCount);
|
||||
MNNQuantizeFP16(weightFp16.data(), quanCommon->weightFloat.get(), weightCount);
|
||||
fp16WeightPtr = weightFp16.data();
|
||||
}
|
||||
} else {
|
||||
// the data type of weight is fp32, then quantize weight to be fp16 data type
|
||||
int size = convParam->weight()->size();
|
||||
weightFp16.resize(size);
|
||||
MNNQuantizeFP16(weightFp16.data(), convParam->weight()->data(), size);
|
||||
fp16WeightPtr = weightFp16.data();
|
||||
}
|
||||
|
||||
const auto weightSrcPtr = fp16WeightPtr;
|
||||
int cur = 0;
|
||||
for (int dz = 0; dz < outputChannel; ++dz) {
|
||||
const int dzi = dz / ARMV82_CHANNEL_UNIT;
|
||||
const int dzj = dz % ARMV82_CHANNEL_UNIT;
|
||||
|
||||
auto dstDz = weightDstPtr + dzi * kernelSize * ARMV82_CHANNEL_UNIT + dzj;
|
||||
for (int k = 0; k < kernelSize; ++k) {
|
||||
dstDz[k * ARMV82_CHANNEL_UNIT] = weightSrcPtr[cur++];
|
||||
}
|
||||
}
|
||||
mBiasFp16.reset(Tensor::createDevice<uint16_t>({ocDivUnit * ARMV82_CHANNEL_UNIT}));
|
||||
success = bn->onAcquireBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
if (!success) {
|
||||
mValid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO, bias is fp32, save bias also in fp16?
|
||||
auto biasDstPtr = mBiasFp16->host<FLOAT16>();
|
||||
memset(biasDstPtr, 0, mBiasFp16->size());
|
||||
|
||||
MNNQuantizeFP16(biasDstPtr, convParam->bias()->data(), outputChannel);
|
||||
}
|
||||
|
||||
Arm82ConvolutionDepthwise::~Arm82ConvolutionDepthwise() {
|
||||
if (mWeightFp16 != nullptr) {
|
||||
backend()->onReleaseBuffer(mWeightFp16.get(), Backend::STATIC);
|
||||
}
|
||||
if (mBiasFp16 != nullptr) {
|
||||
backend()->onReleaseBuffer(mBiasFp16.get(), Backend::STATIC);
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode Arm82ConvolutionDepthwise::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
|
||||
int padX = mCommon->padX();
|
||||
int padY = mCommon->padY();
|
||||
|
||||
if (mCommon->padMode() == PadMode_SAME) {
|
||||
int kernelWidthSize = (mCommon->kernelX() - 1) * mCommon->dilateX() + 1;
|
||||
int kernelHeightSize = (mCommon->kernelY() - 1) * mCommon->dilateY() + 1;
|
||||
|
||||
int padNeededWidth = (output->width() - 1) * mCommon->strideX() + kernelWidthSize - input->width();
|
||||
int padNeededHeight = (output->height() - 1) * mCommon->strideY() + kernelHeightSize - input->height();
|
||||
padX = padNeededWidth / 2;
|
||||
padY = padNeededHeight / 2;
|
||||
}
|
||||
|
||||
const int src_width = input->width();
|
||||
const int src_height = input->height();
|
||||
const int dst_width = output->width();
|
||||
const int dst_height = output->height();
|
||||
const int dst_depth_quad = UP_DIV(output->channel(), ARMV82_CHANNEL_UNIT);
|
||||
const int dst_z_step = dst_width * dst_height * ARMV82_CHANNEL_UNIT;
|
||||
const int src_z_step = src_width * src_height * ARMV82_CHANNEL_UNIT;
|
||||
const int dst_y_step = dst_width * ARMV82_CHANNEL_UNIT;
|
||||
const int src_y_step = src_width * ARMV82_CHANNEL_UNIT;
|
||||
const int strideY = mCommon->strideY();
|
||||
const int strideX = mCommon->strideX();
|
||||
const int dilateY = mCommon->dilateY();
|
||||
const int dilateX = mCommon->dilateX();
|
||||
const int dilateY_step = dilateY * src_width * ARMV82_CHANNEL_UNIT;
|
||||
const int dilateX_step = dilateX * ARMV82_CHANNEL_UNIT;
|
||||
const int kernel_height = mCommon->kernelY();
|
||||
const int kernel_width = mCommon->kernelX();
|
||||
const int weight_z_step = kernel_width * kernel_height * ARMV82_CHANNEL_UNIT;
|
||||
int l = 0, t = 0, r = dst_width, b = dst_height;
|
||||
for (; l * strideX - padX < 0; l++) {
|
||||
// do nothing
|
||||
}
|
||||
for (; t * strideY - padY < 0; t++) {
|
||||
// do nothing
|
||||
}
|
||||
for (; (r - 1) * strideX - padX + kernel_width * dilateX > src_width && r > l; r--) {
|
||||
// do nothing
|
||||
}
|
||||
for (; (b - 1) * strideY - padY + kernel_height * dilateY > src_height && b > t; b--) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
const auto weightPtr = mWeightFp16->host<FLOAT16>();
|
||||
const auto biasPtr = mBiasFp16->host<FLOAT16>();
|
||||
const int threadNumber = static_cast<Arm82Backend*>(backend())->numberThread();
|
||||
mThreadNumber = std::min(threadNumber, dst_depth_quad);
|
||||
auto runBasic = [=](FLOAT16* dst_z, const FLOAT16* src_z, const FLOAT16* weight_dz, const FLOAT16* bias_z, int L,
|
||||
int T, int R, int B) {
|
||||
for (int dy = T; dy < B; ++dy) {
|
||||
auto dst_y = dst_z + dy * dst_y_step;
|
||||
const int srcStartY = dy * strideY - padY;
|
||||
const auto src_y = src_z + srcStartY * src_y_step;
|
||||
const int sfy = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));
|
||||
const int efy = ALIMIN(kernel_height, (UP_DIV(src_height - srcStartY, dilateY)));
|
||||
for (int dx = L; dx < R; ++dx) {
|
||||
auto dst_x = dst_y + ARMV82_CHANNEL_UNIT * dx;
|
||||
const int srcStartX = dx * strideX - padX;
|
||||
const auto src_x = src_y + srcStartX * ARMV82_CHANNEL_UNIT;
|
||||
const int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));
|
||||
const int efx = ALIMIN(kernel_width, (UP_DIV(src_width - srcStartX, dilateX)));
|
||||
const int srcIndex = (sfx * dilateX + sfy * dilateY * src_width) * ARMV82_CHANNEL_UNIT;
|
||||
const int weightIndex = (kernel_width * sfy + sfx) * ARMV82_CHANNEL_UNIT;
|
||||
|
||||
MNNDepthWiseFp16C8Unit(dst_x, src_x + srcIndex, weight_dz + weightIndex, bias_z, efx - sfx, efy - sfy,
|
||||
ARMV82_CHANNEL_UNIT * kernel_width, dilateX_step, dilateY_step,
|
||||
(size_t)mRelu, (size_t)mRelu6);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
mThreadFunction = [=](int tId, const FLOAT16* src, FLOAT16* dst) {
|
||||
for (int dz = tId; dz < dst_depth_quad; dz += mThreadNumber) {
|
||||
const auto src_z = src + dz * src_z_step;
|
||||
const auto weight_dz = weightPtr + dz * weight_z_step;
|
||||
const auto bias_dz = biasPtr + dz * ARMV82_CHANNEL_UNIT;
|
||||
auto dst_z = dst + dz * dst_z_step;
|
||||
runBasic(dst_z, src_z, weight_dz, bias_dz, 0, 0, dst_width, t);
|
||||
runBasic(dst_z, src_z, weight_dz, bias_dz, 0, b, dst_width, dst_height);
|
||||
runBasic(dst_z, src_z, weight_dz, bias_dz, 0, t, l, b);
|
||||
runBasic(dst_z, src_z, weight_dz, bias_dz, r, t, dst_width, b);
|
||||
if (r > l) {
|
||||
for (int dy = t; dy < b; ++dy) {
|
||||
const int srcStartY = dy * strideY - padY;
|
||||
const auto src_dy = src_z + srcStartY * src_y_step;
|
||||
auto dst_y = dst_z + dy * dst_y_step;
|
||||
MNNLineDepthWiseFp16C8Unit(
|
||||
dst_y + l * ARMV82_CHANNEL_UNIT, src_dy + (l * strideX - padX) * ARMV82_CHANNEL_UNIT, weight_dz,
|
||||
bias_dz, r - l, strideX * ARMV82_CHANNEL_UNIT, kernel_width, kernel_height, dilateX_step,
|
||||
dilateY_step, (size_t)mRelu, (size_t)mRelu6);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode Arm82ConvolutionDepthwise::onExecute(const std::vector<Tensor*>& inputs,
|
||||
const std::vector<Tensor*>& outputs) {
|
||||
|
||||
auto input = inputs[0];
|
||||
auto output = outputs[0];
|
||||
const int batch = input->batch();
|
||||
|
||||
const int inBatchStride = ROUND_UP(input->channel(), ARMV82_CHANNEL_UNIT) * input->height() * input->width();
|
||||
const int outBatchStride = ROUND_UP(output->channel(), ARMV82_CHANNEL_UNIT) * output->height() * output->width();
|
||||
|
||||
const auto inputPtr = input->host<FLOAT16>();
|
||||
auto outputPtr = output->host<FLOAT16>();
|
||||
|
||||
for (int bIndex = 0; bIndex < batch; ++bIndex) {
|
||||
const auto srcOrigin = inputPtr + bIndex * inBatchStride;
|
||||
auto dstOrigin = outputPtr + bIndex * outBatchStride;
|
||||
|
||||
MNN_CONCURRENCY_BEGIN(tId, mThreadNumber)
|
||||
mThreadFunction((int)tId, srcOrigin, dstOrigin);
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
MNN_CONCURRENCY_END();
|
||||
#else
|
||||
MNN_CONCURRENCY_END();
|
||||
#endif
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
class Arm82ConvolutionDepthwiseCreator : public Arm82Backend::Arm82Creator {
|
||||
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
||||
const MNN::Op* op, Backend* backend) const override {
|
||||
return new Arm82ConvolutionDepthwise(op->main_as_Convolution2D(), backend);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_ARM82_OP_CREATOR(OpType_ConvolutionDepthwise, Arm82ConvolutionDepthwiseCreator);
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif
|
|
@ -1,38 +0,0 @@
|
|||
//
|
||||
// Arm82ConvolutionDepthwise.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2020/01/07.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
#ifndef Arm82ConvolutionDepthwise_hpp
|
||||
#define Arm82ConvolutionDepthwise_hpp
|
||||
|
||||
#include "MNN_generated.h"
|
||||
#include "backend/arm82/Arm82Backend.hpp"
|
||||
#include "core/Execution.hpp"
|
||||
|
||||
namespace MNN {
|
||||
class Arm82ConvolutionDepthwise : public Execution {
|
||||
public:
|
||||
Arm82ConvolutionDepthwise(const MNN::Convolution2D *convParam, Backend *bn);
|
||||
virtual ~Arm82ConvolutionDepthwise();
|
||||
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Tensor> mWeightFp16;
|
||||
std::shared_ptr<Tensor> mBiasFp16;
|
||||
const Convolution2DCommon *mCommon;
|
||||
int mThreadNumber;
|
||||
bool mRelu;
|
||||
bool mRelu6;
|
||||
std::function<void(int tId, const FLOAT16 *src, FLOAT16 *dst)> mThreadFunction;
|
||||
};
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif /* Arm82ConvolutionDepthwise_hpp */
|
||||
|
||||
#endif
|
|
@ -5,17 +5,13 @@
|
|||
// Created by MNN on 2020/2/13.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#if defined(__ANDROID__) || defined(__aarch64__)
|
||||
|
||||
#ifdef __aarch64__
|
||||
#include "backend/arm82/Arm82Eltwise.hpp"
|
||||
#include "backend/arm82/Arm82Backend.hpp"
|
||||
#include "Arm82Eltwise.hpp"
|
||||
#include "Arm82Backend.hpp"
|
||||
#include "core/Macro.h"
|
||||
#include "MNN_generated.h"
|
||||
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
namespace MNN {
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
// Created by MNN on 2020/2/13.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#ifdef __aarch64__
|
||||
#if defined(__ANDROID__) || defined(__aarch64__)
|
||||
|
||||
#ifndef Arm82Eltwise_hpp
|
||||
#define Arm82Eltwise_hpp
|
||||
|
||||
|
@ -27,4 +28,4 @@ private:
|
|||
} // namespace MNN
|
||||
|
||||
#endif /* Arm82Eltwise_hpp */
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,479 @@
|
|||
#if defined(__ANDROID__) || defined(__aarch64__)
|
||||
|
||||
#include "Arm82Functions.hpp"
|
||||
#include "Arm82OptFunc.hpp"
|
||||
#include "Arm82WinogradOptFunc.hpp"
|
||||
#include "Arm82Vec.hpp"
|
||||
#include "backend/cpu/compute/CommonOptFunction.h"
|
||||
|
||||
#if defined(MNN_USE_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
// (UP_DIV(l,8), e, 8) -> (UP_DIV(e,eP), l, eP)
|
||||
void Arm82MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
|
||||
|
||||
// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, eP), hP = 24
|
||||
// parameter: [aStride, l, h, cStride, bExtraStride]
|
||||
// aStride in parameter is deprecated (useless), but for code clean, just retain it
|
||||
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias);
|
||||
|
||||
// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24, e >= 1
|
||||
// parameter: [aStride, l, h, cStride, bExtraStride]
|
||||
void MNNPackedMatMulRemainFP16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
|
||||
|
||||
void MNNConvDwF23MulTransUnitFP16(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, size_t ow);
|
||||
|
||||
void MNNConvDwF23SourceTransUnitFP16(const FLOAT16 *source, FLOAT16 *dest, size_t unit);
|
||||
|
||||
void MNNConvRunForLineDepthwiseFP16(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
|
||||
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep);
|
||||
}
|
||||
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
|
||||
namespace MNN {
|
||||
|
||||
static void MNNMatrixAddFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
|
||||
for (int y = 0; y < height; ++y) {
|
||||
auto a = A + aStride * y, b = B + bStride * y;
|
||||
auto c = C + cStride * y;
|
||||
for (int x = 0; x < widthC8; ++x) {
|
||||
vst1q_f16(c + x * 8, vaddq_f16(vld1q_f16(a + x * 8), vld1q_f16(b + x * 8)));
|
||||
}
|
||||
}
|
||||
}
|
||||
static void MNNMatrixSubFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
|
||||
for (int y = 0; y < height; ++y) {
|
||||
auto a = A + aStride * y, b = B + bStride * y;
|
||||
auto c = C + cStride * y;
|
||||
for (int x = 0; x < widthC8; ++x) {
|
||||
vst1q_f16(c + x * 8, vsubq_f16(vld1q_f16(a + x * 8), vld1q_f16(b + x * 8)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t l, bool transpose) {
|
||||
auto dest = (int16_t*)destC;
|
||||
auto source = (int16_t*)sourceC;
|
||||
int ePack, lPack, hPack;
|
||||
Arm82MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
|
||||
auto hP = (int)h / hPack;
|
||||
auto hR = (int)hP * hPack;
|
||||
if (hR != h) {
|
||||
::memset(dest, 0, UP_DIV(h, hPack) * hPack * l * sizeof(FLOAT16));
|
||||
}
|
||||
if (!transpose) {
|
||||
for (int y = 0; y < hP; ++y) {
|
||||
auto destY = dest + y * hPack * l;
|
||||
auto sourceY = source + y * hPack;
|
||||
for (int x = 0; x < l; ++x) {
|
||||
::memcpy(destY + hPack * x, sourceY + x * h, hPack * sizeof(FLOAT16));
|
||||
}
|
||||
}
|
||||
auto hRemain = h - hR;
|
||||
if (hRemain > 0) {
|
||||
auto destY = dest + hP * hPack * l;
|
||||
auto sourceY = source + hP * hPack;
|
||||
for (int x = 0; x < l; ++x) {
|
||||
::memcpy(destY + hPack * x, sourceY + x * h, hRemain * sizeof(FLOAT16));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (int y = 0; y < h; ++y) {
|
||||
for (int x = 0; x < l; ++x) {
|
||||
dest[(y / hPack * l + x) * hPack + y % hPack] = source[y * l + x];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber,
|
||||
size_t biasNumber) {
|
||||
for (int z = 0; z < biasNumber; ++z) {
|
||||
FLOAT16* dstZ = dst + planeNumber * 8 * z;
|
||||
const FLOAT16* srcZ = src + planeNumber * 8 * z;
|
||||
#ifdef MNN_USE_NEON
|
||||
auto biasZ = vld1q_f16(bias + 8 * z), alphaZ = vld1q_f16(alpha + 8 * z);
|
||||
#else
|
||||
auto biasZ = bias + 8 * z, alphaZ = alpha + 8 * z;
|
||||
#endif
|
||||
for (int p = 0; p < planeNumber; ++p) {
|
||||
FLOAT16* dstX = dstZ + 8 * p;
|
||||
const FLOAT16* srcX = srcZ + 8 * p;
|
||||
#ifdef MNN_USE_NEON
|
||||
auto res = vaddq_f16(vmulq_f16(vld1q_f16(srcX), alphaZ), biasZ);
|
||||
vst1q_f16(dstX, res);
|
||||
#else
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
dstX[k] = srcX[k] * alphaZ[k] + biasZ[k];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNScaleAndAddBiasOutside(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber,
|
||||
size_t biasNumber) {
|
||||
for (size_t p = 0; p < planeNumber; ++p) {
|
||||
FLOAT16* dstPlane = dst + p * biasNumber;
|
||||
const FLOAT16* srcPlane = src + p * biasNumber;
|
||||
for (int z = 0; z < biasNumber; ++z) {
|
||||
dstPlane[z] = srcPlane[z] * alpha[z] + bias[z];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNAddBiasFP16(FLOAT16* dst, const FLOAT16* bias, size_t planeNumber, size_t biasNumber) {
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
for (int i = 0; i < biasNumber; ++i) {
|
||||
auto b = Vec::load(bias + i * 8);
|
||||
for (int j = 0; j < planeNumber; ++j) {
|
||||
auto dstPtr = dst + (i * planeNumber + j) * 8;
|
||||
Vec::save(dstPtr, Vec::load(dstPtr) + b);
|
||||
}
|
||||
}
|
||||
}
|
||||
static void MNNAddBiasReluFP16(FLOAT16* dst, const FLOAT16* bias, size_t planeNumber, size_t biasNumber) {
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
Vec zero((FLOAT16)0);
|
||||
for (int i = 0; i < biasNumber; ++i) {
|
||||
auto b = Vec::load(bias + i * 8);
|
||||
for (int j = 0; j < planeNumber; ++j) {
|
||||
auto dstPtr = dst + (i * planeNumber + j) * 8;
|
||||
auto result = Vec::max(Vec::load(dstPtr) + b, zero);
|
||||
Vec::save(dstPtr, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
static void MNNAddBiasRelu6FP16(FLOAT16* dst, const FLOAT16* bias, size_t planeNumber, size_t biasNumber) {
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
Vec zero((FLOAT16)0), six((FLOAT16)6);
|
||||
for (int i = 0; i < biasNumber; ++i) {
|
||||
auto b = Vec::load(bias + i * 8);
|
||||
for (int j = 0; j < planeNumber; ++j) {
|
||||
auto dstPtr = dst + (i * planeNumber + j) * 8;
|
||||
auto result = Vec::min(Vec::max(Vec::load(dstPtr) + b, zero), six);
|
||||
Vec::save(dstPtr, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNCopyC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
auto srcPtr = source + i * srcStride;
|
||||
auto dstPtr = dest + i * dstStride;
|
||||
Vec::save(dstPtr, Vec::load(srcPtr));
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNAddC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
auto srcPtr = source + i * srcStride;
|
||||
auto dstPtr = dest + i * dstStride;
|
||||
auto value = Vec::load(dstPtr) + Vec::load(srcPtr);
|
||||
Vec::save(dstPtr, value);
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNAxByClampBroadcastC8FP16(float* CF, const float* AF, const float* BF, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
|
||||
auto C = (FLOAT16*)CF;
|
||||
auto A = (FLOAT16*)AF;
|
||||
auto B = (FLOAT16*)BF;
|
||||
using Vec = MNN::Math::Vec<FLOAT16, 8>;
|
||||
auto minF = Vec(parameters[2]);
|
||||
auto maxF = Vec(parameters[3]);
|
||||
auto beta = Vec(parameters[1]);
|
||||
for (int y = 0; y < height; ++y) {
|
||||
auto a = A + aStride * y;
|
||||
auto b = B + 8 * y;
|
||||
auto bv = Vec::load(b);
|
||||
auto c = C + cStride * y;
|
||||
for (int x = 0; x < width; ++x) {
|
||||
auto av = Vec::load(a + 8 * x);
|
||||
auto cv = av + bv * beta;
|
||||
cv = Vec::min(cv, maxF);
|
||||
cv = Vec::max(cv, minF);
|
||||
Vec::save(c + 8 * x, cv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ARM82MultiAndDestTransformCommon(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, int cacheLineSize, int ow) {
|
||||
constexpr int pack = 8;
|
||||
int unit = ow / 2;
|
||||
MNN_ASSERT(cacheLineSize >= 1);
|
||||
for (int x = 0; x < unit; ++x) {
|
||||
int offset = 4 * pack * x, i = 0;
|
||||
Vec m0 = Vec::load(weight + i * 4 * pack) * Vec::load(cacheLine[i] + offset);
|
||||
Vec m1 = Vec::load(weight + (i * 4 + 1) * pack) * Vec::load(cacheLine[i] + offset + pack * 1);
|
||||
Vec m2 = Vec::load(weight + (i * 4 + 2) * pack) * Vec::load(cacheLine[i] + offset + pack * 2);
|
||||
Vec m3 = Vec::load(weight + (i * 4 + 3) * pack) * Vec::load(cacheLine[i] + offset + pack * 3);
|
||||
for (i = 1; i < cacheLineSize; ++i) {
|
||||
m0 = m0 + Vec::load(weight + i * 4 * pack) * Vec::load(cacheLine[i] + offset);
|
||||
m1 = m1 + Vec::load(weight + (i * 4 + 1) * pack) * Vec::load(cacheLine[i] + offset + pack * 1);
|
||||
m2 = m2 + Vec::load(weight + (i * 4 + 2) * pack) * Vec::load(cacheLine[i] + offset + pack * 2);
|
||||
m3 = m3 + Vec::load(weight + (i * 4 + 3) * pack) * Vec::load(cacheLine[i] + offset + pack * 3);
|
||||
}
|
||||
auto o0 = m0 + m1 + m2;
|
||||
auto o1 = m1 - m2 + m3;
|
||||
Vec::save(dest + (2 * x + 0) * pack, o0);
|
||||
Vec::save(dest + (2 * x + 1) * pack, o1);
|
||||
}
|
||||
if (unit * 2 < ow) {
|
||||
int offset = 4 * pack * unit, i = 0;
|
||||
Vec m0 = Vec::load(weight + i * 4 * pack) * Vec::load(cacheLine[i] + offset);
|
||||
Vec m1 = Vec::load(weight + (i * 4 + 1) * pack) * Vec::load(cacheLine[i] + offset + pack);
|
||||
Vec m2 = Vec::load(weight + (i * 4 + 2) * pack) * Vec::load(cacheLine[i] + offset + pack * 2);
|
||||
for (i = 1; i < cacheLineSize; ++i) {
|
||||
m0 = m0 + Vec::load(weight + i * 4 * pack) * Vec::load(cacheLine[i] + offset);
|
||||
m1 = m1 + Vec::load(weight + (i * 4 + 1) * pack) * Vec::load(cacheLine[i] + offset + pack);
|
||||
m2 = m2 + Vec::load(weight + (i * 4 + 2) * pack) * Vec::load(cacheLine[i] + offset + pack * 2);
|
||||
}
|
||||
auto o0 = m0 + m1 + m2;
|
||||
Vec::save(dest + 2 * unit * pack, o0);
|
||||
}
|
||||
}
|
||||
// unit: winograd unit (output is w/2)
|
||||
void ARM82SourceTransformCommon(const FLOAT16 *source, FLOAT16 *dest, int unit, int iw, int pad, int su, int eu) {
|
||||
constexpr int pack = 8; // float16x8
|
||||
for (int x = 0; x < su; ++x) {
|
||||
auto dstX = dest + 4 * pack * x;
|
||||
auto sx = x * 2 - (int)pad;
|
||||
auto ex = sx + 4;
|
||||
auto clampSx = std::max(sx, 0);
|
||||
auto clampEx = std::min(ex, (int)iw);
|
||||
Vec v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int i = clampSx; i < clampEx; ++i) {
|
||||
v[i - sx] = Vec::load(source + pack * i);
|
||||
}
|
||||
auto m0 = v[0] - v[2];
|
||||
auto m1 = v[1] + v[2];
|
||||
auto m2 = v[2] - v[1];
|
||||
auto m3 = v[3] - v[1];
|
||||
Vec::save(dstX + pack * 0, m0);
|
||||
Vec::save(dstX + pack * 1, m1);
|
||||
Vec::save(dstX + pack * 2, m2);
|
||||
Vec::save(dstX + pack * 3, m3);
|
||||
}
|
||||
MNNConvDwF23SourceTransUnitFP16(source + pack * (su * 2 - pad), dest + 4 * pack * su, eu - su);
|
||||
for (int x = eu; x < unit; ++x) {
|
||||
auto dstX = dest + 4 * pack * x;
|
||||
auto sx = x * 2 - (int)pad;
|
||||
auto ex = sx + 4;
|
||||
auto clampSx = std::max(sx, 0);
|
||||
auto clampEx = std::min(ex, (int)iw);
|
||||
Vec v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int i = clampSx; i < clampEx; ++i) {
|
||||
v[i - sx] = Vec::load(source + pack * i);
|
||||
}
|
||||
auto m0 = v[0] - v[2];
|
||||
auto m1 = v[1] + v[2];
|
||||
auto m2 = v[2] - v[1];
|
||||
auto m3 = v[3] - v[1];
|
||||
Vec::save(dstX + pack * 0, m0);
|
||||
Vec::save(dstX + pack * 1, m1);
|
||||
Vec::save(dstX + pack * 2, m2);
|
||||
Vec::save(dstX + pack * 3, m3);
|
||||
}
|
||||
}
|
||||
|
||||
void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22, FLOAT16* xAddr,
|
||||
size_t cStride, size_t eSub, size_t hSub) {
|
||||
const int pack = 8;
|
||||
for (int y = 0; y < hSub; ++y) {
|
||||
auto c11Y = c11 + y * cStride;
|
||||
auto c12Y = c12 + y * cStride;
|
||||
auto c22Y = c22 + y * cStride;
|
||||
auto c21Y = c21 + y * cStride;
|
||||
auto xY = xAddr + y * eSub * pack;
|
||||
for (int x = 0; x < eSub; ++x) {
|
||||
auto xv = vld1q_f16(xY + x * pack);
|
||||
auto c21v = vld1q_f16(c21Y + x * pack);
|
||||
auto c11v = vld1q_f16(c11Y + x * pack);
|
||||
auto c22v = vld1q_f16(c22Y + x * pack);
|
||||
auto c12v = vld1q_f16(c12Y + x * pack);
|
||||
c12v = c12v + xv;
|
||||
c21v = c12v + c21v;
|
||||
c12v = c22v + c12v;
|
||||
c22v = c22v + c21v;
|
||||
c12v = c11v + c12v;
|
||||
vst1q_f16(c12Y + x * pack, c12v);
|
||||
vst1q_f16(c22Y + x * pack, c22v);
|
||||
vst1q_f16(c21Y + x * pack, c21v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth) {
|
||||
if (1 == area) {
|
||||
::memcpy(dst, src, depth * sizeof(int16_t));
|
||||
return;
|
||||
}
|
||||
int c = (int)depth;
|
||||
int cDiv4 = c / 8;
|
||||
int cAlign = cDiv4 * 8;
|
||||
if (cAlign == c) {
|
||||
for (int hi = 0; hi < area; ++hi) {
|
||||
auto srcHeight = src + hi * 8;
|
||||
auto dstHeight = dst + hi * cDiv4 * 8;
|
||||
for (int ci = 0; ci < cDiv4; ++ci) {
|
||||
vst1q_s16(dstHeight + ci * 8, vld1q_s16(srcHeight + 8 * ci * area));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int hi = 0; hi < area; ++hi) {
|
||||
auto srcHeight = src + hi * 8;
|
||||
auto dstHeight = dst + hi * c;
|
||||
for (int ci = 0; ci < cDiv4; ++ci) {
|
||||
vst1q_s16(dstHeight + ci * 8, vld1q_s16(srcHeight + 8 * ci * area));
|
||||
}
|
||||
}
|
||||
|
||||
int cReamin = c - cAlign;
|
||||
auto srcAlign = src + area * cAlign;
|
||||
auto dstAlign = dst + cAlign;
|
||||
|
||||
for (int hi = 0; hi < area; ++hi) {
|
||||
auto srcHeight = srcAlign + hi * 8;
|
||||
auto dstHeight = dstAlign + hi * c;
|
||||
|
||||
for (int ci = 0; ci < cReamin; ++ci) {
|
||||
dstHeight[ci] = srcHeight[ci];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth) {
|
||||
if (depth == 8) {
|
||||
::memcpy(dst, src, area * depth * sizeof(int16_t));
|
||||
return;
|
||||
}
|
||||
int c = (int)depth;
|
||||
int cDiv4 = c / 8;
|
||||
int cAlign = cDiv4 * 8;
|
||||
for (int hi = 0; hi < area; ++hi) {
|
||||
auto srcHeight = (src + hi * c);
|
||||
auto dstHeight = (dst + hi * 8);
|
||||
for (int ci = 0; ci < cDiv4; ++ci) {
|
||||
vst1q_s16(dstHeight + ci * area * 8, vld1q_s16(srcHeight + 8 * ci));
|
||||
}
|
||||
}
|
||||
|
||||
if (cAlign == c) {
|
||||
return;
|
||||
}
|
||||
|
||||
int cReamin = c - cAlign;
|
||||
auto srcAlign = src + cAlign;
|
||||
auto dstAlign = dst + area * cAlign;
|
||||
|
||||
for (int hi = 0; hi < area; ++hi) {
|
||||
auto srcHeight = srcAlign + hi * c;
|
||||
auto dstHeight = dstAlign + hi * 8;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
dstHeight[i] = 0;
|
||||
}
|
||||
for (int ci = 0; ci < cReamin; ++ci) {
|
||||
dstHeight[ci] = srcHeight[ci];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void MNNConvRunForUnitDepthWiseFP16(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
|
||||
size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
|
||||
int fx, fy;
|
||||
Vec dstValue(0.0f);
|
||||
auto src_z = (const FLOAT16*)src;
|
||||
auto weight_z = (const FLOAT16*)weight;
|
||||
for (fy = 0; fy < fh; ++fy) {
|
||||
auto src_y = src_z + fy * dilateY_step;
|
||||
auto weight_y = weight_z + fy * weight_y_step;
|
||||
for (fx = 0; fx < fw; ++fx) {
|
||||
auto weight_x = weight_y + 8 * fx;
|
||||
auto src_x = src_y + fx * dilateX_step;
|
||||
dstValue = dstValue + Vec::load(src_x) * Vec::load(weight_x);
|
||||
}
|
||||
}
|
||||
Vec::save((FLOAT16*)dst, dstValue);
|
||||
}
|
||||
|
||||
static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t fw, size_t fh,
|
||||
size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
|
||||
int fx, fy;
|
||||
auto src_z = src;
|
||||
auto weight_z = weight;
|
||||
Vec dstV = Vec::load(dst);
|
||||
for (fy = 0; fy < fh; ++fy) {
|
||||
auto src_y = src_z + fy * dilateY_step;
|
||||
auto weight_y = weight_z + fy * weight_y_step;
|
||||
for (fx = 0; fx < fw; ++fx) {
|
||||
Vec weight_x = Vec::load(weight_y + 8 * fx);
|
||||
Vec src_x = Vec::load(src_y + fx * dilateX_step);
|
||||
Vec::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
|
||||
}
|
||||
}
|
||||
}
|
||||
static void _MNNDeconvRunForLineDepthwise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t width, size_t src_w_setup,
|
||||
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
|
||||
int dx;
|
||||
for (dx = 0; dx < width; ++dx) {
|
||||
auto dst_x = dst + dx * 8;
|
||||
auto src_dx = src + src_w_setup * dx;
|
||||
_MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 8, dilateX_step, dilateY_step);
|
||||
}
|
||||
}
|
||||
|
||||
static CoreFunctions* gInstance = nullptr;
|
||||
bool Arm82Functions::init() {
|
||||
#define FUNC_PTR_ASSIGN(dst, src) dst = (decltype(dst))src
|
||||
gInstance = new CoreFunctions;
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNFp32ToLowp, MNNQuantizeFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNLowpToFp32, MNNDequantizeFP16);
|
||||
gInstance->bytes = 2;
|
||||
|
||||
// Packed
|
||||
gInstance->pack = 8;
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackCUnit, MNNPackC8FP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNUnpackCUnit, MNNUnPackC8FP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackCUnitTranspose, MNNPackTransposeInt16C8);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNUnpackCUnitTranspose, MNNUnpackTransposeInt16C8);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNConvRunForUnitDepthWise, MNNConvRunForUnitDepthWiseFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNConvRunForLineDepthwise, MNNConvRunForLineDepthwiseFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNAxByClampBroadcastUnit, MNNAxByClampBroadcastC8FP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNConvDwF23MulTransUnit, MNNConvDwF23MulTransUnitFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNSourceTransformCommonF23, ARM82SourceTransformCommon);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNMultiAndDestTransformCommon23, ARM82MultiAndDestTransformCommon);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNMatrixSub, MNNMatrixSubFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNMatrixAdd, MNNMatrixAddFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge);
|
||||
gInstance->penalty = 2.0f;
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNCopyC4WithStride, MNNCopyC8WithStrideFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
|
||||
|
||||
// MatMul
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
|
||||
|
||||
FUNC_PTR_ASSIGN(gInstance->chooseWinoSourceTransform, Arm82WinogradFunction::chooseSourceTransform);
|
||||
FUNC_PTR_ASSIGN(gInstance->chooseWinoDestTransform, Arm82WinogradFunction::chooseDestTransform);
|
||||
|
||||
gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise;
|
||||
gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise;
|
||||
return true;
|
||||
}
|
||||
|
||||
CoreFunctions* Arm82Functions::get() {
|
||||
return gInstance;
|
||||
}
|
||||
};
|
||||
#endif
|
|
@ -0,0 +1,20 @@
|
|||
#if defined(__ANDROID__) || defined(__aarch64__)
|
||||
|
||||
#ifndef Arm82Functions_hpp
|
||||
#define Arm82Functions_hpp
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "core/Macro.h"
|
||||
#include "backend/cpu/CPUBackend.hpp"
|
||||
namespace MNN {
|
||||
class Arm82Functions {
|
||||
public:
|
||||
static bool init();
|
||||
static CoreFunctions* get();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
#endif // Arm82Functions_hpp
|
||||
#endif
|
|
@ -0,0 +1,107 @@
|
|||
//
|
||||
// Arm82InstanceNorm.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2019/02/28.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
#if defined(__ANDROID__) || defined(__aarch64__)
|
||||
|
||||
#include "Arm82Backend.hpp"
|
||||
#include "Arm82OptFunc.hpp"
|
||||
#include "Arm82InstanceNorm.hpp"
|
||||
#include "MNN_generated.h"
|
||||
#include "core/Concurrency.h"
|
||||
#include <MNN/MNNDefine.h>
|
||||
#include <cmath>
|
||||
#include "core/Macro.h"
|
||||
#include "core/TensorUtils.hpp"
|
||||
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
namespace MNN {
|
||||
|
||||
Arm82InstanceNorm::Arm82InstanceNorm(Backend* backend, const MNN::Op* op) : Execution(backend) {
|
||||
auto normParam = op->main_as_BatchNorm();
|
||||
const int channels = normParam->channels();
|
||||
mEpsilon = normParam->epsilon();
|
||||
mScale.reset(ALIGN_UP8(channels));
|
||||
mScale.clear();
|
||||
if (normParam->slopeData() && normParam->slopeData()->data()) {
|
||||
MNNSlowCopy<FLOAT16, float>(mScale.get(), normParam->slopeData()->data(), channels);
|
||||
}
|
||||
|
||||
mBias.reset(ALIGN_UP8(channels));
|
||||
mBias.clear();
|
||||
if (normParam->biasData() && normParam->biasData()->data()) {
|
||||
MNNSlowCopy<FLOAT16, float>(mBias.get(), normParam->biasData()->data(), channels);
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode Arm82InstanceNorm::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
MNN_ASSERT(3 == inputs.size());
|
||||
MNN_ASSERT(1 == outputs.size());
|
||||
|
||||
auto input = inputs[0], mean = inputs[1], variance = inputs[2], output = outputs[0];
|
||||
const int batch = input->batch(), imageSize = input->stride(1);
|
||||
auto scalePtr = mScale.get(), biasPtr = mBias.get();
|
||||
const int threadNum = ((Arm82Backend*)backend())->numberThread();
|
||||
const int channelBlock = UP_DIV(input->channel(), 8);
|
||||
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
auto inputPtr = input->host<FLOAT16>() + b * ARM82TensorStrideHelper(input, 0);
|
||||
auto meanPtr = mean->host<FLOAT16>() + b * ARM82TensorStrideHelper(mean, 0);
|
||||
auto variancePtr = variance->host<FLOAT16>() + b * ARM82TensorStrideHelper(variance, 0);
|
||||
auto outputPtr = output->host<FLOAT16>() + b * ARM82TensorStrideHelper(output, 0);
|
||||
|
||||
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
|
||||
const int step = UP_DIV(channelBlock, threadNum) * 8, start = tId * step, end = ALIMIN(start + step, channelBlock);
|
||||
for (int c = start; c < end; c += 8) {
|
||||
auto inputPtrZ = inputPtr + c * imageSize;
|
||||
auto outputPtrZ = outputPtr + c * imageSize;
|
||||
#ifdef MNN_USE_NEON
|
||||
float16x8_t meanVec = vld1q_f16(meanPtr + c), varVec = vld1q_f16(variancePtr + c);
|
||||
float16x8_t scaleVec = vld1q_f16(scalePtr + c), biasVec = vld1q_f16(biasPtr + c);
|
||||
float16x8_t epsVec = vdupq_n_f16(mEpsilon), rsqrtVec = vrsqrteq_f16(varVec + epsVec);
|
||||
|
||||
float16x8_t gamma = vmulq_f16(scaleVec, rsqrtVec);
|
||||
float16x8_t beta = vsubq_f16(biasVec, vmulq_f16(meanVec, gamma));
|
||||
for (int i = 0; i < imageSize; ++i) {
|
||||
float16x8_t in = vld1q_f16(inputPtr + i * 8);
|
||||
vst1q_f16(outputPtrZ + i * 8, vaddq_f16(vmulq_f16(in, gamma), beta));
|
||||
}
|
||||
#else
|
||||
FLOAT16 gamma[8], beta[8];
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
int index = c + k;
|
||||
gamma[k] = scalePtr[index] / sqrt(variancePtr[index] + mEpsilon);
|
||||
beta[k] = biasPtr[index] - gamma[k] * meanPtr[index];
|
||||
}
|
||||
for (int i = 0; i < imageSize; ++i) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
outputPtrZ[i * 8 + k] = inputPtrZ[i * 8 + k] * gamma[k] + beta[k];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
}
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
class Arm82InstanceNormCreator : public Arm82Backend::Arm82Creator {
|
||||
public:
|
||||
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
|
||||
const MNN::Op* op, Backend* backend) const override {
|
||||
return new Arm82InstanceNorm(backend, op);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_ARM82_OP_CREATOR(OpType_InstanceNorm, Arm82InstanceNormCreator);
|
||||
|
||||
} // namespace MNN
|
||||
#endif
@@ -0,0 +1,33 @@
//
// Arm82InstanceNorm.hpp
// MNN
//
// Created by MNN on 2019/02/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82InstanceNorm_hpp
#define Arm82InstanceNorm_hpp

#include "Arm82Backend.hpp"
#include "core/AutoStorage.h"
#include "core/Execution.hpp"
#include "MNN_generated.h"

namespace MNN {
class Arm82InstanceNorm : public Execution {
public:
Arm82InstanceNorm(Backend *backend, const MNN::Op *op);
virtual ~Arm82InstanceNorm() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
AutoStorage<FLOAT16> mScale;
AutoStorage<FLOAT16> mBias;
FLOAT16 mEpsilon;
};
} // namespace MNN

#endif /* Arm82InstanceNorm_hpp */
#endif

@@ -5,8 +5,9 @@
// Created by MNN on 2020/04/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "backend/arm82/Arm82Interp.hpp"
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82Interp.hpp"
#include <math.h>
#include "core/Concurrency.h"
#include "core/Macro.h"

@@ -5,11 +5,12 @@
// Created by MNN on 2020/04/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef CPUInterp_hpp
#define CPUInterp_hpp

#include "backend/arm82/Arm82Backend.hpp"
#include "Arm82Backend.hpp"
#include "core/AutoStorage.h"
#include "core/Execution.hpp"

@@ -38,3 +39,4 @@ private:
} // namespace MNN

#endif
#endif

@@ -0,0 +1,120 @@
//
// Arm82Moments.cpp
// MNN
//
// Created by MNN on 2019/02/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82Moments.hpp"
#include "Arm82Backend.hpp"
#include "Arm82Vec.hpp"
#include "core/Concurrency.h"
#include <MNN/MNNDefine.h>
#include "core/Macro.h"
#include "core/TensorUtils.hpp"

#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

using Vec = MNN::Math::Vec<FLOAT16, 8>;
namespace MNN {

Arm82Moments::Arm82Moments(Backend *backend, const MNN::Op *op) : Execution(backend) {
auto momentsParam = op->main_as_MomentsParam();
if (momentsParam->dim()) {
for (int i = 0; i < momentsParam->dim()->size(); ++i) {
mAxis.push_back(momentsParam->dim()->data()[i]);
}
}
mKeepDims = momentsParam->keepDims();
MNN_ASSERT(DataType_DT_FLOAT == momentsParam->dType());
}

ErrorCode Arm82Moments::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
return NO_ERROR;
}

void Arm82Moments::calculateMean(const FLOAT16 *src, FLOAT16 *mean, int channelBlock, int planeNumber) {
const int numberThread = ((Arm82Backend*)backend())->numberThread();
MNN_CONCURRENCY_BEGIN(tId, numberThread) {
int step = UP_DIV(channelBlock, numberThread), start = tId * step, end = ALIMIN(start + step, channelBlock);
for (int z = start; z < end; ++z) {
const FLOAT16* srcZ = src + z * planeNumber * 8;
FLOAT16* meanZ = mean + z * 8;

Vec sum(0);
for (int i = 0; i < planeNumber; ++i) {
sum = sum + Vec::load(srcZ + i * 8);
}
Vec result = sum / (float)planeNumber;
Vec::save(meanZ, result);
}

} MNN_CONCURRENCY_END();
}

void Arm82Moments::calculateVariance(const FLOAT16 *src, const FLOAT16 *mean, FLOAT16* var, int channelBlock, int planeNumber) {
const int numberThread = ((Arm82Backend*)backend())->numberThread();
MNN_CONCURRENCY_BEGIN(tId, numberThread) {
int step = UP_DIV(channelBlock, numberThread), start = tId * step, end = ALIMIN(start + step, channelBlock);
for (int z = start; z < end; ++z) {
const FLOAT16* srcZ = src + z * planeNumber * 8, *meanZ = mean + z * 8;
FLOAT16* varZ = var + z * 8;

Vec sum(0), meanVal = Vec::load(meanZ);
for (int i = 0; i < planeNumber; ++i) {
Vec diff = Vec::load(srcZ + i * 8) - meanVal;
sum = sum + diff * diff;
}
Vec result = sum / (float)planeNumber;
Vec::save(varZ, result);
}

} MNN_CONCURRENCY_END();
}

ErrorCode Arm82Moments::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
MNN_ASSERT(1 == inputs.size());
MNN_ASSERT(2 == outputs.size());
auto input = inputs[0], mean = outputs[0], variance = outputs[1];

// the layout of Moments is NC4HW4, now only support for calculating Moments along height and width
MNN_ASSERT(MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(input)->dimensionFormat);
MNN_ASSERT(mKeepDims);
MNN_ASSERT(mAxis.size() == 2 && mAxis[0] == 2 && mAxis[1] == 3);

const int batch = input->batch(), channelBlock = UP_DIV(mean->channel(), 8);
const int inBatchStride = ARM82TensorStrideHelper(input, 0), outBatchStride = ARM82TensorStrideHelper(mean, 0);
const int planeNumber = ARM82TensorStrideHelper(input, 1);
// mean
for (int b = 0; b < batch; ++b) {
const FLOAT16* srcPtr = input->host<FLOAT16>() + b * inBatchStride;
FLOAT16* meanPtr = mean->host<FLOAT16>() + b * outBatchStride;
calculateMean(srcPtr, meanPtr, channelBlock, planeNumber);
}
// variance
for (int b = 0; b < batch; ++b) {
const FLOAT16* srcPtr = input->host<FLOAT16>() + b * inBatchStride;
const FLOAT16* meanPtr = mean->host<FLOAT16>() + b * outBatchStride;
FLOAT16* variancePtr = variance->host<FLOAT16>() + b * outBatchStride;
calculateVariance(srcPtr, meanPtr, variancePtr, channelBlock, planeNumber);
}

return NO_ERROR;
}

class Arm82MomentsCreator : public Arm82Backend::Arm82Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
return new Arm82Moments(backend, op);
}
};

REGISTER_ARM82_OP_CREATOR(OpType_Moments, Arm82MomentsCreator);

} // namespace MNN
#endif

@@ -0,0 +1,35 @@
//
// Arm82Moments.hpp
// MNN
//
// Created by MNN on 2019/02/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Moments_hpp
#define Arm82Moments_hpp

#include "Arm82Backend.hpp"
#include "core/Execution.hpp"

namespace MNN {

class Arm82Moments : public Execution {
public:
Arm82Moments(Backend* backend, const MNN::Op* op);
virtual ~Arm82Moments() = default;
virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;
virtual ErrorCode onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;

private:
void calculateMean(const FLOAT16 *src, FLOAT16 *mean, int channelBlock, int planeNumber);
void calculateVariance(const FLOAT16 *src, const FLOAT16 *mean, FLOAT16* var, int channelBlock, int planeNumber);
std::vector<int> mAxis;
bool mKeepDims;
};

} // namespace MNN

#endif /* Arm82Moments_hpp */
#endif

@@ -1,26 +1,28 @@
// This file is generated by Shell for ops register
namespace MNN {
extern void ___OpType_ConvolutionDepthwise__Arm82ConvolutionDepthwiseCreator__();
extern void ___OpType_Moments__Arm82MomentsCreator__();
extern void ___OpType_Raster__Arm82RasterFactory__();
extern void ___OpType_Pooling__Arm82PoolingCreator__();
extern void ___OpType_InstanceNorm__Arm82InstanceNormCreator__();
extern void ___OpType_Eltwise__Arm82EltwiseCreator__();
extern void ___OpType_ReLU__Arm82ReluCreator__();
extern void ___OpType_PReLU__Arm82ReluCreator__();
extern void ___OpType_BinaryOp__Arm82BinaryCreator__();
extern void ___OpType_Interp__Arm82InterpCreator__();
extern void ___OpType_Convolution__Arm82ConvolutionCreator__();
extern void ___OpType_UnaryOp__Arm82UnaryCreator__();

void registerArm82Ops() {
#ifdef __aarch64__
___OpType_ConvolutionDepthwise__Arm82ConvolutionDepthwiseCreator__();
#if defined(__ANDROID__) || defined(__aarch64__)
___OpType_Moments__Arm82MomentsCreator__();
___OpType_Raster__Arm82RasterFactory__();
___OpType_Pooling__Arm82PoolingCreator__();
___OpType_InstanceNorm__Arm82InstanceNormCreator__();
___OpType_Eltwise__Arm82EltwiseCreator__();
___OpType_ReLU__Arm82ReluCreator__();
___OpType_PReLU__Arm82ReluCreator__();
___OpType_BinaryOp__Arm82BinaryCreator__();
___OpType_Interp__Arm82InterpCreator__();
___OpType_Convolution__Arm82ConvolutionCreator__();
___OpType_UnaryOp__Arm82UnaryCreator__();
#endif
}
}
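Each `___OpType_X__Creator__` entry above is emitted by a registration macro whose definition is not part of this diff. Based on the naming convention the generated file relies on, a typical token-pasting self-registration macro might look like the hypothetical sketch below; the function body (addArm82Creator) is an assumption, not the confirmed implementation.

// Hypothetical sketch only: the real REGISTER_ARM82_OP_CREATOR lives in the
// backend headers and may differ; this just illustrates how the generated
// extern declarations and registerArm82Ops() calls line up with the macro.
#define REGISTER_ARM82_OP_CREATOR(type, creator)          \
    void ___##type##__##creator##__() {                   \
        Arm82Backend::addArm82Creator(type, new creator); \
    }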
@@ -5,27 +5,71 @@
// Created by MNN on 2019/02/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "backend/arm82/Arm82OptFunc.hpp"
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82OptFunc.hpp"
#include "Arm82Vec.hpp"
#include "core/Macro.h"
#include "half.hpp"

#ifdef MNN_USE_NEON
#include <arm_neon.h>
void MNNQuantizeFP16(FLOAT16* dst, const float* src, int size) {
int sizeDiv4 = size / 4;
int remain = size - sizeDiv4 * 4;
#endif

if (sizeDiv4 > 0) {
MNNQuantizeFP16_UNIT4(dst, src, sizeDiv4);
extern "C" {
void MNNExpFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* params, size_t blockCount);

void MNNQuantizeFP16_UNIT4(int16_t* dst, const float* src, int size);

}

void Arm82MNNExp(FLOAT16* dst, const FLOAT16* src, size_t dataSize) {
int blockCount = dataSize / 16;
if (blockCount > 0) {
static FLOAT16 params[] = {
(FLOAT16)log(2.0f), (FLOAT16)(1.0f / log(2.0f)), 1.0f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
MNNExpFP16(dst, src, params, blockCount);
}

if (remain > 0) {
for (int i = sizeDiv4 * 4; i < size; ++i) {
dst[i] = half_float::half(src[i]);
}
FLOAT16 xLimit = 11, expStep = log(2.0f), expStep_r = 1.0f / expStep;
for (int i = blockCount * 16; i < dataSize; ++i) {
auto x = -src[i];
x = ALIMAX(x, -xLimit);
x = ALIMIN(x, xLimit);
int div = x * expStep_r, expBasicRaw = (div + 15) << 10;
FLOAT16 t = x - div * expStep, expBasic = *(FLOAT16*)(&expBasicRaw);
FLOAT16 expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
dst[i] = (FLOAT16)(expBasic * expRemain);
}
}

void MNNDequantizeFP16(float* dst, const int16_t* srcint, int size) {
void Arm82MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
#ifdef __aarch64__
*hP = 16;
#else
*hP = 8;
#endif
*eP = 12;
*lP = 1;
}

void MNNQuantizeFP16(const float* src, int16_t* dst, size_t size) {
int sizeDiv4 = size / 4;
int remain = size - sizeDiv4 * 4;
if (sizeDiv4 > 0) {
MNNQuantizeFP16_UNIT4(dst, src, sizeDiv4);
src += sizeDiv4 * 4;
dst += sizeDiv4 * 4;
}
if (remain > 0) {
float tempSrc[4];
int16_t tempDst[4];
::memcpy(tempSrc, src, remain * sizeof(float));
MNNQuantizeFP16_UNIT4(tempDst, tempSrc, 1);
::memcpy(dst, tempDst, remain * sizeof(int16_t));
}
}

void MNNDequantizeFP16(const int16_t* srcint, float* dst, size_t size) {
auto src = (const FLOAT16*)srcint;
int sizeDiv4 = size / 4;
int remain = size - sizeDiv4 * 4;

@@ -47,10 +91,18 @@ void MNNDequantizeFP16(float* dst, const int16_t* srcint, int size) {
}
}

void MNNNC4HW4TONC8HW8(uint16_t* dst, const float* source, size_t plane, size_t channel) {
void MNNPackC8FP16(FLOAT16* dest, const FLOAT16* source, size_t plane, size_t channel) {
MNNPackUNIT<FLOAT16, FLOAT16, 8>(dest, source, plane, channel);
}

void MNNUnPackC8FP16(FLOAT16* dest, const FLOAT16* source, size_t plane, size_t channel) {
MNNUnpackUNIT<FLOAT16, FLOAT16, 8>(dest, source, plane, channel);
}

void MNNNC4HW4TONC8HW8(FLOAT16* dst, const float* source, size_t plane, size_t channel) {
const int c4 = UP_DIV(channel, 4);
const int c8 = UP_DIV(channel, 8);
memset(dst, 0, plane * c8 * 8 * sizeof(uint16_t));
memset(dst, 0, plane * c8 * 8 * sizeof(FLOAT16));
#if defined(MNN_USE_NEON) && defined(__aarch64__)
auto dest = (float16_t*)dst;
#else

@@ -78,7 +130,7 @@ void MNNNC4HW4TONC8HW8(uint16_t* dst, const float* source, size_t plane, size_t
}
}

void MNNNC8HW8TONC4HW4(float* dest, const uint16_t* src, size_t plane, size_t channel) {
void MNNNC8HW8TONC4HW4(float* dest, const FLOAT16* src, size_t plane, size_t channel) {
const int c4 = UP_DIV(channel, 4);
#if defined(MNN_USE_NEON) && defined(__aarch64__)
auto source = (float16_t*)src;

@@ -106,7 +158,7 @@ void MNNNC8HW8TONC4HW4(float* dest, const uint16_t* src, size_t plane, size_t ch
}
}

void MNNNC8HW8TONHWC(float* dest, const uint16_t* src, size_t plane, size_t channel) {
void MNNNC8HW8TONHWC(float* dest, const FLOAT16* src, size_t plane, size_t channel) {
int c = (int)channel;
int cDiv8 = c / 8;
int cAlign = cDiv8 * 8;

@@ -115,32 +167,28 @@ void MNNNC8HW8TONHWC(float* dest, const uint16_t* src, size_t plane, size_t chan
#else
auto source = src;
#endif

for (int hi = 0; hi < plane; ++hi) {
const auto srcHeight = source + hi * 8;
float* dstHeight = dest + hi * c;
for (int ci = 0; ci < cDiv8; ++ci) {
#ifdef MNN_USE_NEON
#if defined(MNN_USE_NEON) && defined(__aarch64__)
float16x8_t a = vld1q_f16(srcHeight + 8 * ci * plane);
vst1q_f32(dstHeight + 8 * ci, vcvt_high_f32_f16(a));
#else
half_float::half dataHalf[8];
memcpy(dataHalf, srcHeight + 8 * ci * plane, 8 * sizeof(uint16_t));
memcpy(dataHalf, srcHeight + 8 * ci * plane, 8 * sizeof(FLOAT16));
for (int i = 0; i < 8; ++i) {
dstHeight[ci * 8 + i] = float(dataHalf[i]);
}
#endif
}
}

if (cAlign == c) {
return;
}

int cReamin = c - cAlign;
const auto srcAlign = reinterpret_cast<const half_float::half*>(source + plane * cAlign);
auto dstAlign = dest + cAlign;

for (int hi = 0; hi < plane; ++hi) {
const auto srcHeight = srcAlign + hi * 8;
float* dstHeight = dstAlign + hi * c;

@@ -150,23 +198,4 @@ void MNNNC8HW8TONHWC(float* dest, const uint16_t* src, size_t plane, size_t chan
}
}
}

void MNNNCHWTONC8HW8(uint16_t* dest, const float* source, size_t plane, size_t channel) {
auto halfDest = reinterpret_cast<half_float::half*>(dest);
MNNPackUNIT<float, half_float::half, 8>(halfDest, source, plane, channel);
}

void MNNNC8HW8TONCHW(float* dest, const uint16_t* source, size_t plane, size_t channel) {
auto halfSrc = reinterpret_cast<const half_float::half*>(source);
MNNUnpackUNIT<half_float::half, float, 8>(dest, halfSrc, plane, channel);
}

void MNNNCHWTONC8HW8_NO_TYPE(uint16_t* dest, const uint16_t* source, size_t plane, size_t channel) {
MNNPackUNIT<uint16_t, uint16_t, 8>(dest, source, plane, channel);
}

void MNNNC8HW8TONCHW_NO_TYPE(uint16_t* dest, const uint16_t* source, size_t plane, size_t channel) {
MNNUnpackUNIT<uint16_t, uint16_t, 8>(dest, source, plane, channel);
}

#endif
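The scalar tail of Arm82MNNExp above uses classic range reduction: exp(-v) is computed as 2^div * exp(t), where div is x / ln2 truncated, t is the remainder, 2^div is assembled directly in the fp16 bit pattern ((div + 15) << 10 writes the 5-bit exponent field, which has bias 15 and sits above the 10 mantissa bits), and exp(t) comes from a 5th-order Taylor polynomial, the same constants passed to the MNNExpFP16 assembly kernel. A standalone sketch of that scalar path, assuming a toolchain with the __fp16 extension (the helper name is illustrative):

#include <cmath>
#include <cstdint>
#include <cstring>

// Mirrors the scalar tail of Arm82MNNExp: exp(-v) = 2^div * exp(t).
static __fp16 expNegFP16(float v) {
    const float ln2 = std::log(2.0f);
    float x = -v;
    x = std::fmax(std::fmin(x, 11.0f), -11.0f);    // clamp to the fp16-safe range
    int div = (int)(x / ln2);                      // power-of-two part
    float t = x - div * ln2;                       // remainder
    uint16_t bits = (uint16_t)((div + 15) << 10);  // 2^div as raw fp16 bits
    __fp16 expBasic;
    std::memcpy(&expBasic, &bits, sizeof(bits));
    float expRemain = ((((1.f / 120 * t + 1.f / 24) * t + 1.f / 6) * t + 0.5f) * t + 1.f) * t + 1.f;
    return (__fp16)((float)expBasic * expRemain);
}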
@@ -5,116 +5,61 @@
// Created by MNN on 2019/02/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82OptFunc_hpp
#define Arm82OptFunc_hpp

#include "backend/arm82/Arm82Backend.hpp"
#include "Arm82Backend.hpp"
#include "core/Macro.h"

#define DST_XUNIT 8

#ifdef __cplusplus
extern "C" {
#endif

void MNNGemmFP16C8_UNIT(FLOAT16* dst, const FLOAT16* src, const FLOAT16* weight, const FLOAT16* bias, size_t src_loop,
size_t dst_step, size_t dst_loop, size_t relu, size_t relu6, size_t realDstCount);

void MNNShuffleChannelC8(FLOAT16* dst, const FLOAT16* src, size_t size, size_t halfFlag);
void MNNQuantizeFP16_UNIT4(FLOAT16* dst, const float* src, int size);
void MNNDequantizeFP16(float* dst, const int16_t* src, int size);

#ifdef __cplusplus
}
#endif

void MNNQuantizeFP16(FLOAT16* dst, const float* src, int size);

void Arm82MNNGetMatMulPackMode(int* eP, int *lP, int* hP);
void Arm82MNNExp(FLOAT16* dst, const FLOAT16* src, size_t dataSize);
void MNNQuantizeFP16(const float* src, int16_t* dst, size_t size);
void MNNDequantizeFP16(const int16_t* src, float* dst, size_t size);
void MNNPackC8FP16(FLOAT16* dest, const FLOAT16* source, size_t area, size_t depth);
void MNNUnPackC8FP16(FLOAT16* dest, const FLOAT16* source, size_t area, size_t depth);
// nc4hw4 to nc8hw8 (aka fp32 -> fp16), convert data format and data type
void MNNNC4HW4TONC8HW8(uint16_t* dest, const float* source, size_t plane, size_t channel);
void MNNNC4HW4TONC8HW8(FLOAT16* dest, const float* source, size_t plane, size_t channel);
// nc8hw8 to nc4hw4 (aka fp16 -> fp32)
void MNNNC8HW8TONC4HW4(float* dest, const uint16_t* source, size_t plane, size_t channel);
// nchw to nc8hw8 (aka fp32 -> fp16)
void MNNNCHWTONC8HW8(uint16_t* dest, const float* source, size_t plane, size_t channel);
// nc8hw8 to nchw (aka fp16 -> fp32)
void MNNNC8HW8TONCHW(float* dest, const uint16_t* source, size_t plane, size_t channel);

void MNNNC8HW8TONHWC(float* dest, const uint16_t* src, size_t plane, size_t channel);

void MNNNCHWTONC8HW8_NO_TYPE(uint16_t* dest, const uint16_t* source, size_t plane, size_t channel);
void MNNNC8HW8TONCHW_NO_TYPE(uint16_t* dest, const uint16_t* source, size_t plane, size_t channel);
void MNNNC8HW8TONC4HW4(float* dest, const FLOAT16* source, size_t plane, size_t channel);

template <typename TIN, typename TOUT, int UNIT>
void MNNPackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
int depthCUnit = depth / UNIT;
int depthRemain = depthCUnit * UNIT;
int remain = depth - depthRemain;
int z, x, y;
const TIN* srcChannel[UNIT];
const TIN* srcOffset = src;
for(z = 0; z < depthCUnit; ++z) {
for(y = 0; y < UNIT; ++y) {
srcChannel[y] = srcOffset + area * y;
}
for(x = 0; x < area; ++x) {
for(y = 0; y < UNIT; ++y) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y]++;
dst++;
}
}
srcOffset += area * UNIT;
}
if(remain > 0){
for(y = 0; y < remain; ++y) {
srcChannel[y] = srcOffset + area * y;
}
for(x = 0; x < area; ++x) {
for(y = 0; y < remain; ++y) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y]++;
dst++;
}
for(y = remain; y < UNIT; ++y) {
dst[0] = 0;
dst++;
}
int z, x;
int cur = 0;
memset(dst, 0, area * UP_DIV(depth, UNIT) * UNIT * sizeof(TOUT));
for (z = 0; z < depth; ++z) {
int plane = z / UNIT;
TOUT* dstPlane = plane * area * UNIT + dst;
int offset = z % UNIT;
for (x = 0; x < area; ++x) {
dstPlane[UNIT * x + offset] = TOUT(src[cur++]);
}
}
}

template <typename TIN, typename TOUT, int UNIT>
void MNNUnpackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
int depthCUnit = depth / UNIT;
int depthRemain = depthCUnit * UNIT;
int remain = depth - depthRemain;
int z, x, y;
const TIN* srcChannel[UNIT];
const TIN* srcOffset = src;
for(z = 0; z < depthCUnit; ++z) {
for(y = 0; y < UNIT; ++y) {
srcChannel[y] = srcOffset + y;
for(x = 0; x < area; ++x) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y] += UNIT;
dst++;
}
}
srcOffset += area * UNIT;
}
if(remain > 0){
for(y = 0; y < remain; ++y) {
srcChannel[y] = srcOffset + y;
for(x = 0; x < area; ++x) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y] += UNIT;
dst++;
}
int x;
int z;
int cur = 0;
for (z = 0; z < depth; ++z) {
int plane = z / UNIT;
const TIN* srcPlane = plane * area * UNIT + src;
int offset = z % UNIT;
for (x = 0; x < area; ++x) {
dst[cur++] = TOUT(srcPlane[UNIT * x + offset]);
}
}
}

#endif
template<typename T, typename U>
void MNNSlowCopy(T* dst, const U* src, size_t size) {
for (int i = 0; i < size; ++i) {
dst[i] = (T)src[i];
}
}

#endif // Arm82OptFunc_hpp
#endif
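The MNNPackUNIT / MNNUnpackUNIT templates above implement the channel repacking behind the NC8HW8 layout: channel z lands in block z / UNIT at lane z % UNIT, and padding lanes are zero-filled by the leading memset. A small standalone illustration of the same indexing (plain float, UNIT = 8; helper names are chosen here for the example, not taken from the library):

#include <cstdio>
#include <cstring>

// Pack a planar [depth][area] tensor into [ceil(depth/UNIT)][area][UNIT],
// zero-filling unused lanes, mirroring MNNPackUNIT above.
template <int UNIT>
void packUnit(float* dst, const float* src, int area, int depth) {
    const int blocks = (depth + UNIT - 1) / UNIT;
    std::memset(dst, 0, sizeof(float) * area * blocks * UNIT);
    int cur = 0;
    for (int z = 0; z < depth; ++z) {
        float* dstPlane = dst + (z / UNIT) * area * UNIT;
        const int lane = z % UNIT;
        for (int x = 0; x < area; ++x) {
            dstPlane[UNIT * x + lane] = src[cur++];
        }
    }
}

int main() {
    // 3 channels, 2 pixels, stored channel-major.
    const float src[3 * 2] = {1, 2, 3, 4, 5, 6};
    float dst[1 * 2 * 8];  // one 8-channel block, 2 pixels
    packUnit<8>(dst, src, 2, 3);
    // Pixel 0 holds {1, 3, 5, 0, 0, 0, 0, 0}: channels interleaved, rest zeroed.
    for (int i = 0; i < 8; ++i) {
        std::printf("%g ", dst[i]);
    }
    std::printf("\n");
    return 0;
}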
@@ -5,8 +5,10 @@
// Created by MNN on 2020/01/08.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "backend/arm82/Arm82Pooling.hpp"
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82Pooling.hpp"
#include "Arm82Vec.hpp"
#include "core/Concurrency.h"
#include "core/Macro.h"

@@ -14,6 +16,8 @@
#include <arm_neon.h>
#endif

using Vec = MNN::Math::Vec<FLOAT16, 8>;

namespace MNN {

static void poolingMaxFp16Unit(FLOAT16 *dst, int outputWidth, int outputHeight, const FLOAT16 *src, int inputWidth,

@@ -30,34 +34,16 @@ static void poolingMaxFp16Unit(FLOAT16 *dst, int outputWidth, int outputHeight,

auto dstCurPtr = dst + (oy * outputWidth + ox) * ARMV82_CHANNEL_UNIT;

#ifdef MNN_USE_NEON
float16x8_t curIn, curOut;
curOut = vdupq_n_f16(float16_t(-65504.0));
#else
// init
FLOAT16 curOut[ARMV82_CHANNEL_UNIT];
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
curOut[i] = -65504.0;
}
#endif
Vec curIn;
Vec curOut(-65504.0);
for (int y = kys; y < kye; ++y) {
for (int x = kxs; x < kxe; ++x) {
const int inOffset = ((srcOriginY + y) * inputWidth + srcOriginX + x) * ARMV82_CHANNEL_UNIT;
#ifdef MNN_USE_NEON
curIn = vld1q_f16(src + inOffset);
curOut = vmaxq_f16(curIn, curOut);
#else
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
curOut[i] = std::max(curOut[i], src[inOffset + i]);
}
#endif
curIn = Vec::load(src + inOffset);
curOut = Vec::max(curIn, curOut);
}
}
#ifdef MNN_USE_NEON
vst1q_f16(dstCurPtr, curOut);
#else
memcpy(dstCurPtr, curOut, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
#endif
Vec::save(dstCurPtr, curOut);
}
}
}

@@ -77,39 +63,15 @@ static void poolingAvgFp16Unit(FLOAT16 *dst, int outputWidth, int outputHeight,

auto dstCurPtr = dst + (oy * outputWidth + ox) * ARMV82_CHANNEL_UNIT;

#ifdef MNN_USE_NEON
float16x8_t curIn, curOut;
curOut = vdupq_n_f16(float16_t(0));
float16x8_t size = vdupq_n_f16(float16_t(kernelCount));
#else
// init
FLOAT16 curOut[ARMV82_CHANNEL_UNIT];
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
curOut[i] = 0;
}
#endif
Vec curOut(0), size(kernelCount);
for (int y = kys; y < kye; ++y) {
for (int x = kxs; x < kxe; ++x) {
const int inOffset = ((srcOriginY + y) * inputWidth + srcOriginX + x) * ARMV82_CHANNEL_UNIT;
const auto srcUnit = src + inOffset;
#ifdef MNN_USE_NEON
curIn = vld1q_f16(srcUnit);
curOut = vaddq_f16(curIn, curOut);
#else
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
curOut[i] = curOut[i] + srcUnit[i];
}
#endif
curOut = curOut + Vec::load(srcUnit);
}
}
#ifdef MNN_USE_NEON
vst1q_f16(dstCurPtr, vdivq_f16(curOut, size));
#else
for (int i = 0; i < ARMV82_CHANNEL_UNIT; ++i) {
curOut[i] = curOut[i] / kernelCount;
}
memcpy(dstCurPtr, curOut, sizeof(FLOAT16) * ARMV82_CHANNEL_UNIT);
#endif
Vec::save(dstCurPtr, curOut / size);
}
}
}

@@ -192,11 +154,7 @@ ErrorCode Arm82Pooling::onExecute(const std::vector<Tensor *> &inputs, const std

MNN_CONCURRENCY_BEGIN(tId, mThreadNumber)
mThreadFunction((int)tId, srcOrigin, dstOrigin);
#ifdef MNN_USE_THREAD_POOL
MNN_CONCURRENCY_END();
#else
MNN_CONCURRENCY_END();
#endif
}

return NO_ERROR;

@@ -212,4 +170,4 @@ class Arm82PoolingCreator : public Arm82Backend::Arm82Creator {
REGISTER_ARM82_OP_CREATOR(OpType_Pooling, Arm82PoolingCreator);

} // namespace MNN
#endif
#endif

@@ -5,12 +5,13 @@
// Created by MNN on 2020/01/08.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Pooling_hpp
#define Arm82Pooling_hpp

#include "MNN_generated.h"
#include "backend/arm82/Arm82Backend.hpp"
#include "Arm82Backend.hpp"
#include "core/Execution.hpp"

namespace MNN {

@@ -5,7 +5,7 @@
// Created by MNN on 2020/5/25.
// Copyright © 2018 Alibaba. All rights reserved.
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82Raster.hpp"
#include "math/Vec.hpp"

@@ -5,10 +5,10 @@
// Created by MNN on 2020/5/25.
// Copyright © 2018 Alibaba. All rights reserved.
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Raster_hpp
#define Arm82Raster_hpp
#ifdef __aarch64__
#include "Arm82Backend.hpp"
#include "core/Execution.hpp"
#include <map>

@@ -35,5 +35,5 @@ private:
bool mFast = false;
};
}
#endif
#endif /* Arm82Raster_hpp */
#endif

@@ -31,7 +31,7 @@ def generateCPUFile(rootDir):
f.write("extern void " + l + '();\n')
f.write('\n')
f.write('void registerArm82Ops() {\n')
f.write("#ifdef __aarch64__\n")
f.write("#if defined(__ANDROID__) || defined(__aarch64__)\n")
for l in funcNames:
f.write(l+'();\n')
f.write("#endif\n")

@@ -5,17 +5,18 @@
// Created by MNN on 2020/2/13.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#include <limits>

#include "backend/arm82/Arm82Relu.hpp"
#include "Arm82Relu.hpp"
#include "MNN_generated.h"
#include "backend/arm82/Arm82Backend.hpp"
#include "backend/arm82/Arm82OptFunc.hpp"
#include "Arm82Backend.hpp"
#include "Arm82OptFunc.hpp"
#include "core/Concurrency.h"
#include "core/Macro.h"
#include "half.hpp"

#include <algorithm>
#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

@@ -32,7 +33,7 @@ static void _MNNArm82PReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FL
#ifdef MNN_USE_NEON
float16x8_t value = vld1q_f16(src + i * ARMV82_CHANNEL_UNIT);
float16x8_t mulSlope = vmulq_f16(value, slopeV);
float16x8_t lessThanZero = vcleq_f16(value, value_0);
uint16x8_t lessThanZero = vcleq_f16(value, value_0);

vst1q_f16(dst + i * ARMV82_CHANNEL_UNIT, vbslq_f16(lessThanZero, mulSlope, value));
#else

@@ -50,52 +51,51 @@ static void _MNNArm82PReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FL
}

static void _MNNArm82LeakyReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 slope, size_t length) {
#ifdef MNN_USE_NEON
float16x8_t value_0 = vmovq_n_f16(0);
float16x8_t slopeV = vmovq_n_f16(slope);
#endif
auto lC8 = length / ARMV82_CHANNEL_UNIT;
auto remain = length % ARMV82_CHANNEL_UNIT;

for (int i = 0; i < length; ++i) {
#ifdef MNN_USE_NEON
float16x8_t value = vld1q_f16(src + i * ARMV82_CHANNEL_UNIT);
for (int i = 0; i < lC8; ++i) {
float16x8_t value = vld1q_f16(src);
float16x8_t mulSlope = vmulq_f16(value, slopeV);
float16x8_t lessThanZero = vcleq_f16(value, value_0);

vst1q_f16(dst + i * ARMV82_CHANNEL_UNIT, vbslq_f16(lessThanZero, mulSlope, value));
#else
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
int index = i * ARMV82_CHANNEL_UNIT + j;
if (src[index] < 0) {
dst[index] = src[index] * slope;
} else {
dst[index] = src[index];
}
}
#endif
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
vst1q_f16(dst, vbslq_f16(lessThanZero, mulSlope, value));
src += ARMV82_CHANNEL_UNIT;
dst += ARMV82_CHANNEL_UNIT;
}
if (remain > 0) {
float16_t tempSrc[ARMV82_CHANNEL_UNIT];
float16_t tempDst[ARMV82_CHANNEL_UNIT];
::memcpy(tempSrc, src, remain * sizeof(int16_t));
float16x8_t value = vld1q_f16(tempSrc);
float16x8_t mulSlope = vmulq_f16(value, slopeV);
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
vst1q_f16(tempDst, vbslq_f16(lessThanZero, mulSlope, value));
::memcpy(dst, tempDst, remain * sizeof(int16_t));
}
}

static void _MNNArm82ReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, size_t length) {
#ifdef MNN_USE_NEON
float16x8_t value_0 = vmovq_n_f16(0);
#endif
auto lC8 = length / ARMV82_CHANNEL_UNIT;
auto remain = length % ARMV82_CHANNEL_UNIT;
for (int i = 0; i < lC8; ++i) {
float16x8_t value = vld1q_f16(src);
uint16x8_t lessThanZero = vcleq_f16(value, value_0);

for (int i = 0; i < length; ++i) {
#ifdef MNN_USE_NEON
float16x8_t value = vld1q_f16(src + i * ARMV82_CHANNEL_UNIT);
float16x8_t lessThanZero = vcleq_f16(value, value_0);

vst1q_f16(dst + i * ARMV82_CHANNEL_UNIT, vbslq_f16(lessThanZero, value_0, value));
#else
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
int index = i * ARMV82_CHANNEL_UNIT + j;
if (src[index] < 0) {
dst[index] = 0;
} else {
dst[index] = src[index];
}
}
#endif
vst1q_f16(dst, vbslq_f16(lessThanZero, value_0, value));
dst += ARMV82_CHANNEL_UNIT;
src += ARMV82_CHANNEL_UNIT;
}
if (remain > 0) {
float16_t tempSrc[ARMV82_CHANNEL_UNIT];
float16_t tempDst[ARMV82_CHANNEL_UNIT];
::memcpy(tempSrc, src, remain * sizeof(int16_t));
float16x8_t value = vld1q_f16(tempSrc);
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
vst1q_f16(tempDst, vbslq_f16(lessThanZero, value_0, value));
::memcpy(dst, tempDst, remain * sizeof(int16_t));
}
}

@@ -106,41 +106,37 @@ Arm82Relu::Arm82Relu(Backend *backend, float slope) : Execution(backend) {
ErrorCode Arm82Relu::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto input = inputs[0];
auto output = outputs[0];
const int dimension = input->dimensions();
MNN_ASSERT(4 == dimension);
const int batch = input->batch();
const int channel = input->channel();
const int width = input->width();
const int height = input->height();
const int channelDivUnit = UP_DIV(channel, ARMV82_CHANNEL_UNIT);
const int batchAndChannel = batch * channelDivUnit;
const int plane = width * height;

auto size = ARM82TensorElementSizeHelper(input);
auto schedule = static_cast<CPUBackend*>(backend())->multiThreadDivide(size);

const auto src = input->host<FLOAT16>();
auto dst = output->host<FLOAT16>();

if (abs(mSlope) < std::numeric_limits<float>::epsilon()) {
// relu
mThreadNumbers = static_cast<Arm82Backend *>(backend())->numberThread();
MNN_CONCURRENCY_BEGIN(tId, mThreadNumbers)
for (int b = (int)tId; b < batchAndChannel; b += mThreadNumbers) {
_MNNArm82ReluWithChannel(dst + b * plane * ARMV82_CHANNEL_UNIT,
src + b * plane * ARMV82_CHANNEL_UNIT,
plane);
}
MNN_CONCURRENCY_END();
MNN_CONCURRENCY_BEGIN(tId, schedule.second) {
int start = schedule.first * (int)tId;
int realSize = schedule.first;
if (tId == schedule.second -1 ) {
realSize = size - start;
}

_MNNArm82ReluWithChannel(dst + start,
src + start, realSize);
} MNN_CONCURRENCY_END();
} else {
// leakyrelu
FLOAT16 slopeHalf = half_float::half(mSlope);
mThreadNumbers = static_cast<Arm82Backend *>(backend())->numberThread();
MNN_CONCURRENCY_BEGIN(tId, mThreadNumbers)
for (int b = (int)tId; b < batchAndChannel; b += mThreadNumbers) {
_MNNArm82LeakyReluWithChannel(dst + b * plane * ARMV82_CHANNEL_UNIT,
src + b * plane * ARMV82_CHANNEL_UNIT,
slopeHalf,
plane);
}
MNN_CONCURRENCY_END();
MNN_CONCURRENCY_BEGIN(tId, schedule.second) {
int start = schedule.first * (int)tId;
int realSize = schedule.first;
if (tId == schedule.second -1 ) {
realSize = size - start;
}

_MNNArm82LeakyReluWithChannel(dst + start,
src + start, slopeHalf, realSize);
} MNN_CONCURRENCY_END();
}

return NO_ERROR;

@@ -154,16 +150,14 @@ Arm82PRelu::Arm82PRelu(Backend *backend, const Op *op) : Execution(backend) {
if (!allocRes) {
return;
}
auto slopePtr = mSlope->host<FLOAT16>();
MNNQuantizeFP16(slopePtr, param->slope()->data(), slopeLength);
auto slopePtr = mSlope->host<int16_t>();
MNNQuantizeFP16(param->slope()->data(), slopePtr, slopeLength);
}

ErrorCode Arm82PRelu::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
const auto input = inputs[0];
auto output = outputs[0];

const int dimension = input->dimensions();
MNN_ASSERT(4 == dimension);
const int batch = input->batch();
const int channel = input->channel();
const int width = input->width();

@@ -5,7 +5,8 @@
// Created by MNN on 2020/2/13.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Relu_hpp
#define Arm82Relu_hpp

@@ -21,7 +22,6 @@ public:

private:
float mSlope = 0.0;
int mThreadNumbers;
};

class Arm82PRelu : public Execution {
@@ -0,0 +1,237 @@
//
// Arm82Unary.cpp
// MNN
//
// Created by MNN on 2018/08/02.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#include <vector>
#include <cmath>
#include <algorithm>
#include "Arm82Unary.hpp"
#include "Arm82Backend.hpp"
#include "core/Macro.h"
#include "core/OpCommonUtils.hpp"
#include "core/Concurrency.h"
#include "MNN_generated.h"

#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

namespace MNN {
Arm82Unary::Arm82Unary(Backend *b, UnaryOpOperation type) : MNN::Execution(b), mType(type) {
// nothing to do
}

ErrorCode Arm82Unary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
MNN_ASSERT(1 == outputs.size());
auto dtype = inputs[0]->getType();
MNN_ASSERT(dtype == halide_type_of<float>() || dtype == halide_type_of<int32_t>());
return NO_ERROR;
}

template <typename Func, typename T>
static ErrorCode _unaryOp(void* inputPtr, void* outputPtr, int elementSize, Backend* bn) {
Func f;
auto backend = [bn]() {
return bn;
};
const T *inputData = (T*)inputPtr;
T *outputData = (T *)outputPtr;
auto numberThread = ((CPUBackend*)bn)->threadNumber();
MNN_CONCURRENCY_BEGIN(tId, numberThread) {
for (int i=tId; i<elementSize; i+=numberThread) {
outputData[i] = f(inputData[i]);
}
}
MNN_CONCURRENCY_END();
return NO_ERROR;
}

class UnarySquare {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return x * x;
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
return vmulq_f16(x, x);
}
#endif
};

class UnaryRsqrt {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return 1.f / sqrt(x);
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
return vrsqrteq_f16(x);
}
#endif
};

class UnarySqrt {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return sqrt(x);
}
#if defined(MNN_USE_NEON) && defined(__aarch64__)
static float16x8_t vecFunc(const float16x8_t& x) {
return vsqrtq_f16(x);
}
#endif
};

class UnaryNeg {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return -x;
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
return vnegq_f16(x);
}
#endif
};

class UnaryAbs {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return abs(x);
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
return vabsq_f16(x);
}
#endif
};

class UnaryRecipocal {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
return 1.f / x;
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
return vrecpeq_f16(x);
}
#endif
};

class UnaryHardSwish {
public:
static FLOAT16 scalarFunc(const FLOAT16& x) {
if (x <= -3) {
return 0;
} else if (x >= 3) {
return x;
} else {
return x * (x + 3) / 6;
}
}
#ifdef MNN_USE_NEON
static float16x8_t vecFunc(const float16x8_t& x) {
float16x8_t value_l = vmovq_n_f16(-3);
float16x8_t value_h = vmovq_n_f16(3);
float16x8_t value_d = vmovq_n_f16(1.f/6);
float16x8_t value_z = vmovq_n_f16(0);
uint16x8_t right = vcleq_f16(x, value_l);
float16x8_t middle = vmulq_f16(vmulq_f16(x, vaddq_f16(x, value_h)), value_d);
float16x8_t tmp = vbslq_f16(right, x, middle);
uint16x8_t left = vcgtq_f16(x, value_l);
return vbslq_f16(left, tmp, value_z);
}
#endif
};

template <typename Helper>
ErrorCode Arm82Unary::onExecuteInternal(Tensor* input, Tensor* output) {
const int threadNum = ((Arm82Backend*)backend())->threadNumber();
const int count = ARM82TensorElementSizeHelper(output);
const FLOAT16* inputData = input->host<FLOAT16>();
FLOAT16* outputData = output->host<FLOAT16>();

MNN_CONCURRENCY_BEGIN(tId, threadNum) {
int realSize = UP_DIV(UP_DIV(count, ARMV82_CHANNEL_UNIT), threadNum) * ARMV82_CHANNEL_UNIT;
int startIndex = tId * realSize, endIndex = ALIMIN(startIndex + realSize, count);
if (endIndex > startIndex) {
int index = startIndex, readSizeUnit = realSize / ARMV82_CHANNEL_UNIT;
#ifdef MNN_USE_NEON
for (int i = 0; i < readSizeUnit; ++i, index += ARMV82_CHANNEL_UNIT) {
float16x8_t in = vld1q_f16(inputData + index);
vst1q_f16(outputData + index, Helper::vecFunc(in));
}
#endif
for (; index < endIndex; ++index) {
outputData[index] = Helper::scalarFunc(inputData[index]);
}
}
} MNN_CONCURRENCY_END();

return NO_ERROR;
}

ErrorCode Arm82Unary::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto input = inputs[0];
auto output = outputs[0];
ErrorCode code;

switch (mType) {
case UnaryOpOperation_ABS:
code = onExecuteInternal<UnaryAbs>(input, output);
break;
case UnaryOpOperation_SQUARE:
code = onExecuteInternal<UnarySquare>(input, output);
break;
case UnaryOpOperation_RSQRT:
code = onExecuteInternal<UnaryRsqrt>(input, output);
break;
case UnaryOpOperation_NEG:
code = onExecuteInternal<UnaryNeg>(input, output);
break;
#if defined(__aarch64__)
case UnaryOpOperation_SQRT:
code = onExecuteInternal<UnarySqrt>(input, output);
break;
#endif
case UnaryOpOperation_RECIPROCAL:
code = onExecuteInternal<UnaryRecipocal>(input, output);
break;
case UnaryOpOperation_HARDSWISH:
code = onExecuteInternal<UnaryHardSwish>(input, output);
break;
default:
MNN_ASSERT(false);
break;
}

return code;
}

class Arm82UnaryCreator : public Arm82Backend::Arm82Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
auto type = op->main_as_UnaryOp()->opType();
std::vector<UnaryOpOperation> supportOps = {
UnaryOpOperation_ABS, UnaryOpOperation_SQUARE, UnaryOpOperation_RSQRT,
UnaryOpOperation_NEG, UnaryOpOperation_SQRT, UnaryOpOperation_RECIPROCAL
};
if (std::find(supportOps.begin(), supportOps.end(), type) != supportOps.end()) {
return new Arm82Unary(backend, type);
}
return nullptr;
}
};

REGISTER_ARM82_OP_CREATOR(OpType_UnaryOp, Arm82UnaryCreator);

} // namespace MNN

#endif

@@ -0,0 +1,30 @@
//
// Arm82Unary.hpp
// MNN
//
// Created by MNN on 2018/08/02.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Unary_hpp
#define Arm82Unary_hpp

#include "core/Execution.hpp"
#include "MNN_generated.h"

namespace MNN {
class Arm82Unary : public Execution {
public:
Arm82Unary(Backend *b, UnaryOpOperation type);
virtual ~Arm82Unary() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
template <typename Helper> ErrorCode onExecuteInternal(Tensor*, Tensor*);

protected:
UnaryOpOperation mType;
};
} // namespace MNN
#endif /* Arm82Unary_hpp */
#endif

@@ -0,0 +1,117 @@
//
// Arm82Vec.hpp
// MNN
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82Vec_hpp
#define Arm82Vec_hpp

#include "Arm82Backend.hpp"
#include "math/Vec.hpp"

#ifdef MNN_USE_NEON
namespace MNN {
namespace Math {
template<>
struct Vec<FLOAT16, 8> {
using VecType = Vec<FLOAT16, 8>;
float16x8_t value;
Vec() {
}
Vec(const float v) {
value = vdupq_n_f16(v);
}
Vec(const float16x8_t v) {
value = v;
}
Vec(const VecType& lr) {
value = lr.value;
}
Vec(const VecType&& lr) {
value = std::move(lr.value);
}
float operator[](size_t i) {
return value[i];
}
static VecType load(const FLOAT16* addr) {
VecType v = { vld1q_f16(addr) };
return v;
}
static void save(FLOAT16* addr, const VecType& v) {
vst1q_f16(addr, v.value);
}
static VecType max(const VecType& v1, const VecType& v2) {
VecType dst = { vmaxq_f16(v1.value, v2.value) };
return dst;
}
static VecType min(const VecType& v1, const VecType& v2) {
VecType dst = { vminq_f16(v1.value, v2.value) };
return dst;
}
static void mla(VecType& v1, const VecType& v2, const VecType& v3) {
v1.value = vfmaq_f16(v1.value, v2.value, v3.value);
}
static void mls(VecType& v1, const VecType& v2, const VecType& v3) {
v1.value = vfmsq_f16(v1.value, v2.value, v3.value);
}
VecType operator+(const VecType& lr) {
VecType dst = { vaddq_f16(value, lr.value) };
return dst;
}
VecType operator-(const VecType& lr) {
VecType dst = { vsubq_f16(value, lr.value) };
return dst;
}
VecType operator*(float lr) {
VecType dst = { vmulq_n_f16(value, lr) };
return dst;
}
VecType operator*(const VecType& lr) {
VecType dst = { vmulq_f16(value, lr.value) };
return dst;
}
VecType operator/(float lr) {
#if defined(__aarch64__)
VecType dst = { vdivq_f16(value, vdupq_n_f16(lr)) };
#else
VecType dst;
for (int i = 0; i < 8; ++i) {
dst.value[i] = value[i] / lr;
}
#endif
return dst;
}
VecType operator/(const VecType& lr) {
#if defined(__aarch64__)
VecType dst = { vdivq_f16(value, lr.value) };
#else
VecType dst;
for (int i = 0; i < 8; ++i) {
dst.value[i] = value[i] / lr.value[i];
}
#endif
return dst;
}
VecType& operator=(const VecType& lr) {
value = lr.value;
return *this;
}
VecType& operator=(const VecType&& lr) {
value = std::move(lr.value);
return *this;
}
VecType operator-() {
VecType dst = { vnegq_f16(value) };
return dst;
}
};
} // namespace Math
} // namespace MNN
#endif /* MNN_USE_NEON */

#endif // Arm82Vec_hpp
#endif
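The Vec<FLOAT16, 8> specialisation above gives the fp16 kernels in this change (pooling, moments, Winograd transforms) a uniform 8-lane wrapper over the float16x8_t intrinsics. A minimal usage sketch, assuming an armv8.2-a+fp16 target with MNN_USE_NEON defined and FLOAT16 coming from Arm82Backend.hpp; the helper name is illustrative:

#include "Arm82Vec.hpp"

using Vec = MNN::Math::Vec<FLOAT16, 8>;

// Apply y = x * gamma + beta over `count` 8-lane units, the same pattern
// the instance-norm and pooling kernels above use.
void scaleAndShift(FLOAT16* dst, const FLOAT16* src, int count,
                   const FLOAT16* gammaPtr, const FLOAT16* betaPtr) {
    Vec gamma = Vec::load(gammaPtr);  // 8 per-channel scales
    Vec beta  = Vec::load(betaPtr);   // 8 per-channel shifts
    for (int i = 0; i < count; ++i) {
        Vec acc = beta;
        Vec::mla(acc, Vec::load(src + i * 8), gamma);  // acc += x * gamma
        Vec::save(dst + i * 8, acc);
    }
}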
@@ -0,0 +1,209 @@
//
// Arm82WinogradOptFunc.cpp
// MNN
//
// Created by MNN on 2018/10/08.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#include "Arm82WinogradOptFunc.hpp"
#include "Arm82Vec.hpp"
#include "Arm82OptFunc.hpp"
#include <cstring>
#include <memory>
#include "core/Macro.h"
#include "math/Vec.hpp"
using Vec = MNN::Math::Vec<FLOAT16, 8>;

namespace MNN {

static void _sourceTransformUnit4x4(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);

auto m0 = s0 - s2;
auto m1 = s1 + s2;
auto m2 = s2 - s1;
auto m3 = s3 - s1;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
Vec::save(dstStart + 3 * dstStep, m3);
}
static void _destTransformUnit4x2(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);

auto m0 = s0 + s1 + s2;
auto m1 = (s1 - s2) + s3;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
}
static void _destTransformUnit4x3(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);

auto m0 = s0 + s1 + s2;
auto m1 = (s1 - s2);
auto m2 = (s1 + s2) + s3;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
}

#define LOAD6 \
Vec s0 = Vec::load(srcBlock + 0 * srcStep); \
Vec s1 = Vec::load(srcBlock + 1 * srcStep); \
Vec s2 = Vec::load(srcBlock + 2 * srcStep); \
Vec s3 = Vec::load(srcBlock + 3 * srcStep); \
Vec s4 = Vec::load(srcBlock + 4 * srcStep); \
Vec s5 = Vec::load(srcBlock + 5 * srcStep);

static void _sourceTransformUnit6x6(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
LOAD6;
Vec m0 = s0 * (FLOAT16)4 - s2 * (FLOAT16)5 + s4;

Vec m1 = (s1 + s2) * (-(FLOAT16)4) + (s3 + s4);
Vec m2 = (s1 - s2) * ((FLOAT16)4) + (s4 - s3);

Vec m3 = s1 * -(FLOAT16)2 - s2 + s3 * (FLOAT16)2 + s4;
Vec m4 = s1 * (FLOAT16)2 - s2 - s3 * (FLOAT16)2 + s4;

Vec m5 = s1 * (FLOAT16)4 - s3 * (FLOAT16)5 + s5;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
Vec::save(dstStart + 3 * dstStep, m3);
Vec::save(dstStart + 4 * dstStep, m4);
Vec::save(dstStart + 5 * dstStep, m5);
}

static void _destTransformUnit6x5(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);
Vec s4 = Vec::load(srcBlock + 4 * srcStep);
Vec s5 = Vec::load(srcBlock + 5 * srcStep);

auto m0 = s0 + s1 + s2 + s3 + s4;
auto m1 = (s1 - s2) + (s3 - s4) * (FLOAT16)2;
auto m2 = (s1 + s2) + (s3 + s4) * (FLOAT16)4;
auto m3 = (s1 - s2) + (s3 - s4) * (FLOAT16)8;
auto m4 = (s1 + s2) + (s3 + s4) * (FLOAT16)16 + s5;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
Vec::save(dstStart + 3 * dstStep, m3);
Vec::save(dstStart + 4 * dstStep, m4);
}
static void _destTransformUnit6x4(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);
Vec s4 = Vec::load(srcBlock + 4 * srcStep);
Vec s5 = Vec::load(srcBlock + 5 * srcStep);
auto v0 = s3 + s4;
auto v1 = s3 - s4;
auto v2 = s1 + s2;
auto v3 = s1 - s2;

auto m0 = s0 + v2 + v0;
auto m1 = v3 + v1 + v1;
auto m2 = v2 + v0 * (FLOAT16)4;
auto m3 = v3 + v1 * (FLOAT16)8 + s5;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
Vec::save(dstStart + 3 * dstStep, m3);
}
static void _destTransformUnit6x3(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);
Vec s4 = Vec::load(srcBlock + 4 * srcStep);
Vec s5 = Vec::load(srcBlock + 5 * srcStep);

auto m0 = s0 + s1 + s2 + s3 + s4;
auto m1 = (s1 - s2) + (s3 - s4) * (FLOAT16)2;
auto m2 = (s1 + s2) + (s3 + s4) * (FLOAT16)4 + s5;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
Vec::save(dstStart + 2 * dstStep, m2);
}
static void _destTransformUnit6x2(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep) {
Vec s0 = Vec::load(srcBlock + 0 * srcStep);
Vec s1 = Vec::load(srcBlock + 1 * srcStep);
Vec s2 = Vec::load(srcBlock + 2 * srcStep);
Vec s3 = Vec::load(srcBlock + 3 * srcStep);
Vec s4 = Vec::load(srcBlock + 4 * srcStep);
Vec s5 = Vec::load(srcBlock + 5 * srcStep);

auto m0 = s0 + s1 + s2 + s3 + s4;
auto m1 = (s1 - s2) + (s3 - s4) * (FLOAT16)2 + s5;

Vec::save(dstStart + 0 * dstStep, m0);
Vec::save(dstStart + 1 * dstStep, m1);
}

static Arm82WinogradFunction::TransformFunc gProcUnit6[] = {
nullptr, // 0
nullptr, // 1
_destTransformUnit6x2,
_destTransformUnit6x3,
_destTransformUnit6x4,
_destTransformUnit6x5,
};


Arm82WinogradFunction::TransformFunc Arm82WinogradFunction::chooseSourceTransform(int k, int w) {
if (6 == k && 6 == w) {
return _sourceTransformUnit6x6;
}
if (4 == k && 4 == w) {
return _sourceTransformUnit4x4;
}
MNN_ASSERT(false);
return nullptr;
}

Arm82WinogradFunction::TransformFunc Arm82WinogradFunction::chooseDestTransform(int k, int h) {
if (6 == k) {
if (h <= 1 || h > 5) {
return nullptr;
}
return gProcUnit6[h];
}
if (2 == h && 4 == k) {
return _destTransformUnit4x2;
}
if (3 == h && 4 == k) {
return _destTransformUnit4x3;
}
return nullptr;
}

int Arm82MNNGetConvTileNumber() {
int eP, lP, hP;
Arm82MNNGetMatMulPackMode(&eP, &lP, &hP);
return eP; // 8
}

} // namespace MNN
#endif

@@ -0,0 +1,30 @@
//
// Arm82WinogradOptFunc.hpp
// MNN
//
// Created by MNN on 2018/10/08.
// Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#ifndef Arm82WinogradOptFunc_hpp
#define Arm82WinogradOptFunc_hpp

#include "Arm82Backend.hpp"

namespace MNN {
class Arm82WinogradFunction {
public:
typedef void (*TransformFunc)(const FLOAT16* srcBlock, FLOAT16* dstStart, size_t srcStep, size_t dstStep);

/*Use the generator with interp 0.5*/
static TransformFunc chooseSourceTransform(int k, int w);
static TransformFunc chooseDestTransform(int k, int h);
};

int Arm82MNNGetConvTileNumber();

} // namespace MNN

#endif /* Arm82WinogradOptFunc_hpp */
#endif
@@ -1,22 +1,18 @@

file(GLOB MNN_ARM82_SRCS "${CMAKE_CURRENT_LIST_DIR}/*.cpp")
file(GLOB MNN_ARM82_SRCS "${CMAKE_CURRENT_LIST_DIR}/*.cpp" "${CMAKE_CURRENT_LIST_DIR}/compute/*.cpp")

set(COMPILE_ARM64 OFF)
if(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR IOS_ARCH STREQUAL "arm64")
set(COMPILE_ARM64 ON)
endif()

file(GLOB MNN_ARM82_SRCS_ASM "${CMAKE_CURRENT_LIST_DIR}/asm/arm64/*")

add_library(
MNN_Arm82
OBJECT
${MNN_ARM82_SRCS}
${MNN_ARM82_SRCS_ASM}
)

if(COMPILE_ARM64)
if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?")
file(GLOB MNN_ARM82_SRCS_ASM "${CMAKE_CURRENT_LIST_DIR}/asm/arm32/*")
add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM})
target_compile_options(MNN_Arm82 PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64")
file(GLOB MNN_ARM82_SRCS_ASM "${CMAKE_CURRENT_LIST_DIR}/asm/arm64/*")
add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM})
target_compile_options(MNN_Arm82 PRIVATE -march=armv8.2-a+fp16)
else()
# Building a fat binary requires multiple separate builds and lipo-by-hand under CMake's design
endif()

target_include_directories(MNN_Arm82 PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_include_directories(MNN_Arm82 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/compute/)
target_include_directories(MNN_Arm82 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/asm/)