mirror of https://github.com/alibaba/MNN.git
MNN:Sync Sync Internal 2.9.0
This commit is contained in:
parent ba4ecd9792
commit 7cad2ee83f
@@ -58,9 +58,18 @@ option(MNN_ENABLE_COVERAGE "Build with coverage enable" OFF)
option(MNN_BUILD_PROTOBUFFER "Build with protobuffer in MNN" ON)
option(MNN_BUILD_OPENCV "Build OpenCV api in MNN." OFF)
option(MNN_BUILD_LLM "Build llm library based MNN." OFF)
option(MNN_BUILD_DIFFUSION "Build diffusion demo based MNN." OFF)
option(MNN_INTERNAL "Build with MNN internal features, such as model authentication, metrics logging" OFF)
option(MNN_JNI "Build MNN Jni for java to use" OFF)

IF (OHOS)
  include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
  export_headers(DIR ${CMAKE_SOURCE_DIR}/include/MNN)
  IF (MNN_BUILD_OPENCV)
    export_headers(DIR ${CMAKE_SOURCE_DIR}/tools/cv/include/cv)
  ENDIF()
ENDIF()

IF (NOT DEFINED MNN_USE_SPARSE_COMPUTE)
  set(MNN_USE_SPARSE_COMPUTE ON)
ENDIF()
|
@@ -263,12 +272,12 @@ endif()
option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
if (NOT MSVC)
    if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
        set(CMAKE_CXX_STANDARD 17)
        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
    elseif(MNN_USE_CPP11)
        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
    else()
        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x")
|
@@ -493,7 +502,7 @@ endif()
IF(MNN_COREML)
    add_definitions(-DMNN_COREML_ENABLED=1)
    add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/source/backend/coreml/)

    IF(MNN_SEP_BUILD)
        list(APPEND MNN_DEPS MNNCoreML)
        list(APPEND MNN_EXTRA_DEPENDS MNNCoreML)
|
@@ -631,7 +640,7 @@ ENDIF()

IF(MNN_BUILD_LLM)
    # add_definitions(-DMNN_BUILD_LLM)
    include(${CMAKE_CURRENT_LIST_DIR}/llm/CMakeLists.txt)
    include(${CMAKE_CURRENT_LIST_DIR}/transformers/llm/engine/CMakeLists.txt)
ENDIF()

# NPU
|
@@ -755,6 +764,9 @@ list(REMOVE_ITEM MNN_TARGETS MNN)
IF(MNN_BUILD_DEMO)
    include(${CMAKE_CURRENT_LIST_DIR}/demo/exec/CMakeLists.txt)
ENDIF()
IF(MNN_BUILD_DIFFUSION AND MNN_BUILD_OPENCV AND MNN_IMGCODECS)
    include(${CMAKE_CURRENT_LIST_DIR}/transformers/diffusion/CMakeLists.txt)
ENDIF()
IF(MNN_BUILD_TOOLS)
    include(${CMAKE_CURRENT_LIST_DIR}/tools/cpp/CMakeLists.txt)
ENDIF()
@@ -127,7 +127,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward

    auto modelBuffer = revertor->getBuffer();
    const auto bufferSize = revertor->getBufferSize();
    auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize));
    auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy);
    revertor.reset();
    net->setSessionMode(MNN::Interpreter::Session_Release);
    MNN::ScheduleConfig config;
|
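The change above swaps the plain shared_ptr constructor for one that passes MNN::Interpreter::destroy as the deleter, so the interpreter is released by the library that allocated it. A minimal sketch of the same pattern (the helper name is illustrative):

```cpp
#include <cstddef>
#include <memory>
#include <MNN/Interpreter.hpp>

// Sketch: wrap an Interpreter created from a model buffer with the explicit
// destroy deleter instead of relying on a plain `delete`.
std::shared_ptr<MNN::Interpreter> makeNet(const void* buffer, size_t size) {
    return std::shared_ptr<MNN::Interpreter>(
        MNN::Interpreter::createFromBuffer(buffer, size),  // may be nullptr on invalid input
        MNN::Interpreter::destroy);                        // free on the library side
}
```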
|||
|
|
@@ -80,6 +80,7 @@ MNN is built with CMake; the available CMake macro switches are listed below:
| MNN_OPENCV_BENCH | Whether to enable the performance benchmark for MNN's OpenCV module; defaults to `OFF` |
| MNN_VULKAN_IMAGE | Use the Image memory mode when building MNN's Vulkan backend, to support FP16 and GPU acceleration on some mobile devices; defaults to `ON` |
| MNN_LOW_MEMORY | Whether to support low-memory mode; when enabled, a weight-quantized model run with `low_memory` set dequantizes at compute time; defaults to `OFF` |
| MNN_SUPPORT_RENDER | Whether to support render-related operator implementations; defaults to `OFF` |
| MNN_SUPPORT_TRANSFORMER_FUSE | Whether to support fused Transformer operator implementations; defaults to `OFF` |
| MNN_BUILD_LLM | Whether to build the MNN-based llm library and demo; defaults to `OFF` |
| MNN_BUILD_DIFFUSION | Whether to build the MNN-based diffusion demo; requires MNN_BUILD_OPENCV and MNN_IMGCODECS to be enabled; defaults to `OFF` |
@@ -60,6 +60,7 @@
- `winogradExample.out` Winograd example
- `fuseTest` Tests custom GPU operators; currently only the Vulkan Buffer mode is supported
- `GpuInterTest.out` Tests GPU memory input; currently only the OpenCL Buffer mode and OpenGL texture mode are supported, and MNN_OPENCL and MNN_OPENGL must be enabled at build time
- `LoRA` Merges LoRA weights into the model weights
## Benchmark tool
- Related build options
  - `MNN_BUILD_BENCHMARK` Whether to build the benchmark tool
@@ -27,6 +27,7 @@ namespace Express {

void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread) {
    std::lock_guard<std::mutex> _l(mMutex);

    if(type == MNN_FORWARD_AUTO) {
        ScheduleConfig sConfig;
        sConfig.type = type;

@@ -41,10 +42,12 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig&
            info.numThread = 4;
        }
        mAttr->firstType = std::make_pair(type, info.numThread);

        info.user = (BackendConfig*)&config;
        std::shared_ptr<Runtime> bn(creator->onCreate(info));
        mRuntimes[mAttr->firstType] = bn;
        auto firstIter = mRuntimes.find(mAttr->firstType);
        if (firstIter == mRuntimes.end()) {
            info.user = (BackendConfig*)&config;
            std::shared_ptr<Runtime> bn(creator->onCreate(info));
            mRuntimes[mAttr->firstType] = bn;
        }
    } else {
        auto creator = MNNGetExtraRuntimeCreator(type);
        if (nullptr == creator) {

@@ -56,11 +59,14 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig&
        Backend::Info info;
        info.type = type;
        mAttr->firstType = std::make_pair(type, numberThread);
        info.mode = Backend::Info::DIRECT;
        info.numThread = numberThread;
        info.user = (BackendConfig*)&config;
        std::shared_ptr<Runtime> bn(creator->onCreate(info));
        mRuntimes[mAttr->firstType] = bn;
        auto firstIter = mRuntimes.find(mAttr->firstType);
        if (firstIter == mRuntimes.end()) {
            info.mode = Backend::Info::DIRECT;
            info.numThread = numberThread;
            info.user = (BackendConfig*)&config;
            std::shared_ptr<Runtime> bn(creator->onCreate(info));
            mRuntimes[mAttr->firstType] = bn;
        }
    }
    _refreshRuntime();
}
|
|
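The rewritten setGlobalExecutorConfig above now looks the runtime up in mRuntimes and only creates a new one when the key is missing, instead of unconditionally overwriting the cached entry. A minimal usage sketch (backend and thread count are illustrative):

```cpp
#include <MNN/MNNForwardType.h>
#include <MNN/expr/Executor.hpp>

// Sketch: configuring the global executor twice with the same settings; after
// this change the second call finds the Runtime already cached and reuses it.
void configureExecutor() {
    MNN::BackendConfig config;
    config.precision = MNN::BackendConfig::Precision_Low;
    auto executor = MNN::Express::Executor::getGlobalExecutor();
    executor->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
    executor->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);  // reuses the cached runtime
}
```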
@ -155,6 +161,10 @@ std::shared_ptr<Executor> Executor::newExecutor(MNNForwardType type,
|
|||
const BackendConfig& config,
|
||||
int numberThread) {
|
||||
auto creator = MNNGetExtraRuntimeCreator(type);
|
||||
if(nullptr == creator) {
|
||||
MNN_ERROR("Don't support %d\n", type);
|
||||
return nullptr;
|
||||
}
|
||||
Backend::Info info;
|
||||
info.type = type;
|
||||
info.numThread = numberThread;
|
||||
|
|
|
|||
|
|
@@ -98,6 +98,21 @@ static std::vector<std::shared_ptr<BufferStorage>> preRearrangeWeights( // NOLIN
            }
            break;
        }
        case MNN::OpType_Attention: {
            exe.reset(backend->onCreate({}, {}, op));
            if (exe.get() == nullptr) {
                exe.reset(backupBackend->onCreate({}, {}, op));
            }
            if (nullptr == exe) {
                break;
            }
            // The exe can't clone
            if (!exe->onClone(nullptr, op, nullptr)) {
                exe = nullptr;
                break;
            }
            break;
        }
        default: {
            break;
        }
@@ -68,7 +68,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR_IMP(x) #x
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 8
#define MNN_VERSION_PATCH 4
#define MNN_VERSION_MINOR 9
#define MNN_VERSION_PATCH 0
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
|
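The version bump above moves MNN from 2.8.4 to 2.9.0, matching the commit title. A one-line check of the resulting macro (sketch):

```cpp
#include <cstdio>
#include <MNN/MNNDefine.h>

int main() {
    // MNN_VERSION expands to "2.9.0" after this change.
    printf("built against MNN %s\n", MNN_VERSION);
    return 0;
}
```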
@@ -0,0 +1,16 @@
#!/bin/bash
cmake ../../../ \
-DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake \
-DCMAKE_BUILD_TYPE=Release \
-DOHOS_ARCH="arm64-v8a" \
-DOHOS_STL=c++_static \
-DMNN_USE_LOGCAT=false \
-DMNN_BUILD_BENCHMARK=ON \
-DMNN_USE_SSE=OFF \
-DMNN_SUPPORT_BF16=OFF \
-DMNN_BUILD_TEST=ON \
-DOHOS_PLATFORM_LEVEL=9 \
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3

make -j4
@@ -0,0 +1,16 @@
#!/bin/bash
DIR=yanxing

make -j16
hdc file send ./libMNN.so /data/local/tmp/$DIR/libMNN.so
hdc file send ./libMNN_Express.so /data/local/tmp/$DIR/libMNN_Express.so
hdc file send ./MNNV2Basic.out /data/local/tmp/$DIR/MNNV2Basic.out
hdc file send ./ModuleBasic.out /data/local/tmp/$DIR/ModuleBasic.out
# hdc shell "cd /data/local/tmp/$DIR && rm -r output"
# hdc shell "cd /data/local/tmp/$DIR && mkdir output"
hdc file send ./unitTest.out /data/local/tmp/$DIR/unitTest.out
hdc file send ./testModel.out /data/local/tmp/$DIR/testModel.out
hdc file send ./testModelWithDescribe.out /data/local/tmp/$DIR/testModelWithDescribe.out
hdc file send ./backendTest.out /data/local/tmp/$DIR/backendTest.out
hdc file send ./timeProfile.out /data/local/tmp/$DIR/timeProfile.out
hdc file send ./run_test.out /data/local/tmp/$DIR/run_test.out
@ -174,7 +174,6 @@
|
|||
489D7A7B2550FDC800AD896A /* MetalUnary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */; };
|
||||
489D7A7D2550FDC900AD896A /* MetalConvolution.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A2C2550FDC800AD896A /* MetalConvolution.mm */; };
|
||||
489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */; };
|
||||
489D7A7F2550FDC900AD896A /* MetalReLU.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2E2550FDC800AD896A /* MetalReLU.hpp */; };
|
||||
489D7A802550FDC900AD896A /* MetalEltwise.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */; };
|
||||
489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A302550FDC800AD896A /* MetalPooling.hpp */; };
|
||||
489D7A822550FDC900AD896A /* MetalPReLU.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A312550FDC800AD896A /* MetalPReLU.hpp */; };
|
||||
|
|
@ -184,7 +183,6 @@
|
|||
489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */; };
|
||||
489D7A8B2550FDC900AD896A /* MetalConvolutionWinograd.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */; };
|
||||
489D7A8C2550FDC900AD896A /* MetalDeconvolution.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */; };
|
||||
489D7A8D2550FDC900AD896A /* MetalReLU.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3C2550FDC800AD896A /* MetalReLU.mm */; };
|
||||
489D7A8E2550FDC900AD896A /* MetalPooling.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3D2550FDC800AD896A /* MetalPooling.mm */; };
|
||||
489D7A902550FDC900AD896A /* MetalConvolution.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */; };
|
||||
489D7A912550FDC900AD896A /* MetalScale.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A402550FDC800AD896A /* MetalScale.mm */; };
|
||||
|
|
@ -746,6 +744,7 @@
|
|||
9558333D29B0947300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558333C29B0947300488807 /* MNNGelu.S */; };
|
||||
9558334729B09A2300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334629B09A2300488807 /* MNNGelu.S */; };
|
||||
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334A29B09A7B00488807 /* MNNGeluFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
|
||||
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */; };
|
||||
956F52E12AB2D692004B13D9 /* ImageProcessUtils.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */; };
|
||||
956F52E32AB2D6A1004B13D9 /* ImageProcessUtils.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */; };
|
||||
958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; };
|
||||
|
|
@ -1005,7 +1004,6 @@
|
|||
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalUnary.hpp; sourceTree = "<group>"; };
|
||||
489D7A2C2550FDC800AD896A /* MetalConvolution.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolution.mm; sourceTree = "<group>"; };
|
||||
489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MNNMetalContext.mm; sourceTree = "<group>"; };
|
||||
489D7A2E2550FDC800AD896A /* MetalReLU.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReLU.hpp; sourceTree = "<group>"; };
|
||||
489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalEltwise.hpp; sourceTree = "<group>"; };
|
||||
489D7A302550FDC800AD896A /* MetalPooling.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalPooling.hpp; sourceTree = "<group>"; };
|
||||
489D7A312550FDC800AD896A /* MetalPReLU.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalPReLU.hpp; sourceTree = "<group>"; };
|
||||
|
|
@ -1015,7 +1013,6 @@
|
|||
489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolutionDepthwise.mm; sourceTree = "<group>"; };
|
||||
489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalConvolutionWinograd.hpp; sourceTree = "<group>"; };
|
||||
489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalDeconvolution.mm; sourceTree = "<group>"; };
|
||||
489D7A3C2550FDC800AD896A /* MetalReLU.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReLU.mm; sourceTree = "<group>"; };
|
||||
489D7A3D2550FDC800AD896A /* MetalPooling.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalPooling.mm; sourceTree = "<group>"; };
|
||||
489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalConvolution.hpp; sourceTree = "<group>"; };
|
||||
489D7A402550FDC800AD896A /* MetalScale.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalScale.mm; sourceTree = "<group>"; };
|
||||
|
|
@ -1587,6 +1584,7 @@
|
|||
9558333C29B0947300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
|
||||
9558334629B09A2300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
|
||||
9558334A29B09A7B00488807 /* MNNGeluFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGeluFP16.S; path = ../../../arm82/asm/arm64/MNNGeluFP16.S; sourceTree = "<group>"; };
|
||||
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryLayernorm.cpp; sourceTree = "<group>"; };
|
||||
956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ImageProcessUtils.cpp; sourceTree = "<group>"; };
|
||||
956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ImageProcessUtils.hpp; sourceTree = "<group>"; };
|
||||
958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = "<group>"; };
|
||||
|
|
@ -1842,6 +1840,7 @@
|
|||
4819FB2624C139690050BD09 /* GeometryLSTM.cpp */,
|
||||
4819FB2424C139680050BD09 /* GeometryPoolGrad.cpp */,
|
||||
4819FB2A24C139690050BD09 /* GeometryReduce.cpp */,
|
||||
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */,
|
||||
489404DD24A2FC2B001E456C /* GeometryReverseSequence.cpp */,
|
||||
48FD0349246AA40300456AF5 /* GeometryConvert.cpp */,
|
||||
48FD12BD2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp */,
|
||||
|
|
@ -2150,7 +2149,6 @@
|
|||
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */,
|
||||
489D7A2C2550FDC800AD896A /* MetalConvolution.mm */,
|
||||
489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */,
|
||||
489D7A2E2550FDC800AD896A /* MetalReLU.hpp */,
|
||||
489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */,
|
||||
489D7A302550FDC800AD896A /* MetalPooling.hpp */,
|
||||
489D7A312550FDC800AD896A /* MetalPReLU.hpp */,
|
||||
|
|
@ -2160,7 +2158,6 @@
|
|||
489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */,
|
||||
489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */,
|
||||
489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */,
|
||||
489D7A3C2550FDC800AD896A /* MetalReLU.mm */,
|
||||
489D7A3D2550FDC800AD896A /* MetalPooling.mm */,
|
||||
489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */,
|
||||
489D7A402550FDC800AD896A /* MetalScale.mm */,
|
||||
|
|
@ -3070,7 +3067,6 @@
|
|||
EBECA39924643D320062C7A3 /* Arm82Relu.hpp in Headers */,
|
||||
4838EA7C2611BFE20027232C /* CPUGridSample.hpp in Headers */,
|
||||
92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */,
|
||||
489D7A7F2550FDC900AD896A /* MetalReLU.hpp in Headers */,
|
||||
92FF03D123AA0B5A00AC97F6 /* CPUTopKV2.hpp in Headers */,
|
||||
92FF033F23AA0B5A00AC97F6 /* CPUArgMax.hpp in Headers */,
|
||||
92FF034C23AA0B5A00AC97F6 /* CPUSetDiff1D.hpp in Headers */,
|
||||
|
|
@ -3602,7 +3598,6 @@
|
|||
486E1A9924F5078D00C16006 /* CPURandomUniform.cpp in Sources */,
|
||||
92FF02C823AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
|
||||
92FF045C23AA0B7100AC97F6 /* ShapeBroadcastTo.cpp in Sources */,
|
||||
489D7A8D2550FDC900AD896A /* MetalReLU.mm in Sources */,
|
||||
48747D49245D9D24000B9709 /* RuntimeFactory.cpp in Sources */,
|
||||
92FF02AE23AA0B5A00AC97F6 /* CPUProposal.cpp in Sources */,
|
||||
92FF042723AA0B7100AC97F6 /* ShapeMatMul.cpp in Sources */,
|
||||
|
|
@ -3627,6 +3622,7 @@
|
|||
92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */,
|
||||
92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */,
|
||||
CECF8C87299CAD9400D3875B /* sds.c in Sources */,
|
||||
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */,
|
||||
4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */,
|
||||
CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */,
|
||||
92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */,
|
||||
|
|
@ -4131,7 +4127,7 @@
|
|||
CODE_SIGN_STYLE = Automatic;
|
||||
DEAD_CODE_STRIPPING = YES;
|
||||
DEFINES_MODULE = YES;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
DYLIB_COMPATIBILITY_VERSION = 1;
|
||||
DYLIB_CURRENT_VERSION = 1;
|
||||
DYLIB_INSTALL_NAME_BASE = "@rpath";
|
||||
|
|
@ -4218,7 +4214,7 @@
|
|||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
GCC_ENABLE_CPP_EXCEPTIONS = NO;
|
||||
GCC_ENABLE_CPP_RTTI = NO;
|
||||
HEADER_SEARCH_PATHS = (
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
|
|||
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
|
||||
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
|
||||
|
||||
if (OHOS)
|
||||
include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
|
||||
endif()
|
||||
|
||||
if(PYMNN_INTERNAL_SERVING)
|
||||
file(GLOB_RECURSE SRC ${CMAKE_CURRENT_LIST_DIR}/src/MNN.cc
|
||||
${CMAKE_CURRENT_LIST_DIR}/src/internal/monitor_service.cc
|
||||
|
|
@ -185,12 +189,20 @@ if(WIN32 OR APPLE OR CMAKE_SYSTEM_NAME MATCHES "^Linux")
|
|||
else()
|
||||
target_include_directories(mnnpybridge PRIVATE ${MNN_DIR}/pymnn/src ${MNN_DIR}/pymnn/android/src/main/c/include)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${MNN_DIR}/pymnn/android/src/main/jniLibs/${ANDROID_ABI})
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
|
||||
endif()
|
||||
if(PYMNN_NUMPY_USABLE)
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
if (OHOS)
|
||||
target_link_libraries(mnnpybridge PRIVATE tcpkg::mnn)
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_link_libraries(mnnpybridge PRIVATE tcpkg::alinnpython)
|
||||
endif()
|
||||
export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN)
|
||||
else()
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
|
||||
endif()
|
||||
if(PYMNN_NUMPY_USABLE)
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ struct ExtraT;
|
|||
struct StringVec;
|
||||
struct StringVecT;
|
||||
|
||||
struct AttentionParam;
|
||||
struct AttentionParamT;
|
||||
|
||||
struct FmhaV2Param;
|
||||
struct FmhaV2ParamT;
|
||||
|
||||
|
|
@ -69,6 +72,8 @@ inline const flatbuffers::TypeTable *ExtraTypeTable();
|
|||
|
||||
inline const flatbuffers::TypeTable *StringVecTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *AttentionParamTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable();
|
||||
|
||||
inline const flatbuffers::TypeTable *FmhcaParamTypeTable();
|
||||
|
|
@ -261,6 +266,7 @@ enum OpType {
|
|||
OpType_BatchNorm = 267,
|
||||
OpType_ConvTranspose3D = 268,
|
||||
OpType_ZeroGrad = 269,
|
||||
OpType_Attention = 299,
|
||||
OpType_FmhaV2 = 300,
|
||||
OpType_Fmhca = 301,
|
||||
OpType_SeqLen2Spatial = 302,
|
||||
|
|
@ -281,7 +287,7 @@ enum OpType {
|
|||
OpType_MAX = OpType_GridSample
|
||||
};
|
||||
|
||||
inline const OpType (&EnumValuesOpType())[181] {
|
||||
inline const OpType (&EnumValuesOpType())[182] {
|
||||
static const OpType values[] = {
|
||||
OpType_AbsVal,
|
||||
OpType_QuantizedAdd,
|
||||
|
|
@ -448,6 +454,7 @@ inline const OpType (&EnumValuesOpType())[181] {
|
|||
OpType_BatchNorm,
|
||||
OpType_ConvTranspose3D,
|
||||
OpType_ZeroGrad,
|
||||
OpType_Attention,
|
||||
OpType_FmhaV2,
|
||||
OpType_Fmhca,
|
||||
OpType_SeqLen2Spatial,
|
||||
|
|
@ -769,7 +776,7 @@ inline const char * const *EnumNamesOpType() {
|
|||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"Attention",
|
||||
"FmhaV2",
|
||||
"Fmhca",
|
||||
"SeqLen2Spatial",
|
||||
|
|
@ -1185,11 +1192,12 @@ enum OpParameter {
|
|||
OpParameter_GroupNorm = 95,
|
||||
OpParameter_FmhaV2Param = 96,
|
||||
OpParameter_FmhcaParam = 97,
|
||||
OpParameter_AttentionParam = 98,
|
||||
OpParameter_MIN = OpParameter_NONE,
|
||||
OpParameter_MAX = OpParameter_FmhcaParam
|
||||
OpParameter_MAX = OpParameter_AttentionParam
|
||||
};
|
||||
|
||||
inline const OpParameter (&EnumValuesOpParameter())[98] {
|
||||
inline const OpParameter (&EnumValuesOpParameter())[99] {
|
||||
static const OpParameter values[] = {
|
||||
OpParameter_NONE,
|
||||
OpParameter_QuantizedAdd,
|
||||
|
|
@ -1288,7 +1296,8 @@ inline const OpParameter (&EnumValuesOpParameter())[98] {
|
|||
OpParameter_CumSum,
|
||||
OpParameter_GroupNorm,
|
||||
OpParameter_FmhaV2Param,
|
||||
OpParameter_FmhcaParam
|
||||
OpParameter_FmhcaParam,
|
||||
OpParameter_AttentionParam
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
|
@ -1393,13 +1402,14 @@ inline const char * const *EnumNamesOpParameter() {
|
|||
"GroupNorm",
|
||||
"FmhaV2Param",
|
||||
"FmhcaParam",
|
||||
"AttentionParam",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameOpParameter(OpParameter e) {
|
||||
if (e < OpParameter_NONE || e > OpParameter_FmhcaParam) return "";
|
||||
if (e < OpParameter_NONE || e > OpParameter_AttentionParam) return "";
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesOpParameter()[index];
|
||||
}
|
||||
|
|
@ -1796,6 +1806,10 @@ template<> struct OpParameterTraits<FmhcaParam> {
|
|||
static const OpParameter enum_value = OpParameter_FmhcaParam;
|
||||
};
|
||||
|
||||
template<> struct OpParameterTraits<AttentionParam> {
|
||||
static const OpParameter enum_value = OpParameter_AttentionParam;
|
||||
};
|
||||
|
||||
struct OpParameterUnion {
|
||||
OpParameter type;
|
||||
void *value;
|
||||
|
|
@ -2603,6 +2617,14 @@ struct OpParameterUnion {
|
|||
return type == OpParameter_FmhcaParam ?
|
||||
reinterpret_cast<const FmhcaParamT *>(value) : nullptr;
|
||||
}
|
||||
AttentionParamT *AsAttentionParam() {
|
||||
return type == OpParameter_AttentionParam ?
|
||||
reinterpret_cast<AttentionParamT *>(value) : nullptr;
|
||||
}
|
||||
const AttentionParamT *AsAttentionParam() const {
|
||||
return type == OpParameter_AttentionParam ?
|
||||
reinterpret_cast<const AttentionParamT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
|
||||
|
|
@ -2900,6 +2922,60 @@ inline flatbuffers::Offset<StringVec> CreateStringVec(
|
|||
|
||||
flatbuffers::Offset<StringVec> CreateStringVec(flatbuffers::FlatBufferBuilder &_fbb, const StringVecT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct AttentionParamT : public flatbuffers::NativeTable {
|
||||
typedef AttentionParam TableType;
|
||||
bool kv_cache;
|
||||
AttentionParamT()
|
||||
: kv_cache(true) {
|
||||
}
|
||||
};
|
||||
|
||||
struct AttentionParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef AttentionParamT NativeTableType;
|
||||
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
|
||||
return AttentionParamTypeTable();
|
||||
}
|
||||
bool kv_cache() const {
|
||||
return GetField<uint8_t>(4, 1) != 0;
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<uint8_t>(verifier, 4) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
AttentionParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(AttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<AttentionParam> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct AttentionParamBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_kv_cache(bool kv_cache) {
|
||||
fbb_.AddElement<uint8_t>(4, static_cast<uint8_t>(kv_cache), 1);
|
||||
}
|
||||
explicit AttentionParamBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
AttentionParamBuilder &operator=(const AttentionParamBuilder &);
|
||||
flatbuffers::Offset<AttentionParam> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<AttentionParam>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<AttentionParam> CreateAttentionParam(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
bool kv_cache = true) {
|
||||
AttentionParamBuilder builder_(_fbb);
|
||||
builder_.add_kv_cache(kv_cache);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<AttentionParam> CreateAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct FmhaV2ParamT : public flatbuffers::NativeTable {
|
||||
typedef FmhaV2Param TableType;
|
||||
int32_t heads;
|
||||
|
|
@ -3784,6 +3860,9 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const FmhcaParam *main_as_FmhcaParam() const {
|
||||
return main_type() == OpParameter_FmhcaParam ? static_cast<const FmhcaParam *>(main()) : nullptr;
|
||||
}
|
||||
const AttentionParam *main_as_AttentionParam() const {
|
||||
return main_type() == OpParameter_AttentionParam ? static_cast<const AttentionParam *>(main()) : nullptr;
|
||||
}
|
||||
const flatbuffers::String *name() const {
|
||||
return GetPointer<const flatbuffers::String *>(10);
|
||||
}
|
||||
|
|
@ -4209,6 +4288,10 @@ template<> inline const FmhcaParam *Op::main_as<FmhcaParam>() const {
|
|||
return main_as_FmhcaParam();
|
||||
}
|
||||
|
||||
template<> inline const AttentionParam *Op::main_as<AttentionParam>() const {
|
||||
return main_as_AttentionParam();
|
||||
}
|
||||
|
||||
struct OpBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
|
|
@ -5006,6 +5089,32 @@ inline flatbuffers::Offset<StringVec> CreateStringVec(flatbuffers::FlatBufferBui
|
|||
_data);
|
||||
}
|
||||
|
||||
inline AttentionParamT *AttentionParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new AttentionParamT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void AttentionParam::UnPackTo(AttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
{ auto _e = kv_cache(); _o->kv_cache = _e; };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<AttentionParam> AttentionParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateAttentionParam(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<AttentionParam> CreateAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AttentionParamT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
auto _kv_cache = _o->kv_cache;
|
||||
return MNN::CreateAttentionParam(
|
||||
_fbb,
|
||||
_kv_cache);
|
||||
}
|
||||
|
||||
inline FmhaV2ParamT *FmhaV2Param::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new FmhaV2ParamT();
|
||||
UnPackTo(_o, _resolver);
|
||||
|
|
@ -5902,6 +6011,10 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
|
|||
auto ptr = reinterpret_cast<const FmhcaParam *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case OpParameter_AttentionParam: {
|
||||
auto ptr = reinterpret_cast<const AttentionParam *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -6308,6 +6421,10 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
|
|||
auto ptr = reinterpret_cast<const FmhcaParam *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case OpParameter_AttentionParam: {
|
||||
auto ptr = reinterpret_cast<const AttentionParam *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
@ -6702,6 +6819,10 @@ inline flatbuffers::Offset<void> OpParameterUnion::Pack(flatbuffers::FlatBufferB
|
|||
auto ptr = reinterpret_cast<const FmhcaParamT *>(value);
|
||||
return CreateFmhcaParam(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case OpParameter_AttentionParam: {
|
||||
auto ptr = reinterpret_cast<const AttentionParamT *>(value);
|
||||
return CreateAttentionParam(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -7096,6 +7217,10 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
|
|||
value = new FmhcaParamT(*reinterpret_cast<FmhcaParamT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case OpParameter_AttentionParam: {
|
||||
value = new AttentionParamT(*reinterpret_cast<AttentionParamT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -7588,6 +7713,11 @@ inline void OpParameterUnion::Reset() {
|
|||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case OpParameter_AttentionParam: {
|
||||
auto ptr = reinterpret_cast<AttentionParamT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
|
@ -7776,12 +7906,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 },
|
||||
{ flatbuffers::ET_INT, 0, 0 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
OpTypeTypeTable
|
||||
};
|
||||
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
|
||||
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
|
||||
static const char * const names[] = {
|
||||
"AbsVal",
|
||||
"QuantizedAdd",
|
||||
|
|
@ -7948,6 +8079,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
"BatchNorm",
|
||||
"ConvTranspose3D",
|
||||
"ZeroGrad",
|
||||
"Attention",
|
||||
"FmhaV2",
|
||||
"Fmhca",
|
||||
"SeqLen2Spatial",
|
||||
|
|
@ -7966,7 +8098,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
|
|||
"GridSample"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_ENUM, 181, type_codes, type_refs, values, names
|
||||
flatbuffers::ST_ENUM, 182, type_codes, type_refs, values, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
|
@ -8070,7 +8202,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
{ flatbuffers::ET_SEQUENCE, 0, 93 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 94 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 95 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 96 }
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 96 },
|
||||
{ flatbuffers::ET_SEQUENCE, 0, 97 }
|
||||
};
|
||||
static const flatbuffers::TypeFunction type_refs[] = {
|
||||
QuantizedAddTypeTable,
|
||||
|
|
@ -8169,7 +8302,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
CumSumTypeTable,
|
||||
GroupNormTypeTable,
|
||||
FmhaV2ParamTypeTable,
|
||||
FmhcaParamTypeTable
|
||||
FmhcaParamTypeTable,
|
||||
AttentionParamTypeTable
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"NONE",
|
||||
|
|
@ -8269,10 +8403,11 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
|
|||
"CumSum",
|
||||
"GroupNorm",
|
||||
"FmhaV2Param",
|
||||
"FmhcaParam"
|
||||
"FmhcaParam",
|
||||
"AttentionParam"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_UNION, 98, type_codes, type_refs, nullptr, names
|
||||
flatbuffers::ST_UNION, 99, type_codes, type_refs, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
|
@ -8376,6 +8511,19 @@ inline const flatbuffers::TypeTable *StringVecTypeTable() {
|
|||
return &tt;
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *AttentionParamTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_BOOL, 0, -1 }
|
||||
};
|
||||
static const char * const names[] = {
|
||||
"kv_cache"
|
||||
};
|
||||
static const flatbuffers::TypeTable tt = {
|
||||
flatbuffers::ST_TABLE, 1, type_codes, nullptr, nullptr, names
|
||||
};
|
||||
return &tt;
|
||||
}
|
||||
|
||||
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable() {
|
||||
static const flatbuffers::TypeCode type_codes[] = {
|
||||
{ flatbuffers::ET_INT, 0, -1 }
|
||||
|
|
|
|||
|
|
@@ -187,6 +187,7 @@ enum OpType : int {
    ZeroGrad,

    // User define op
    Attention = 299,
    FmhaV2 = 300,
    Fmhca = 301,
    SeqLen2Spatial = 302,

@@ -218,7 +219,7 @@ table Extra {
    engine: string;
    info: [byte];
    attr:[Attribute];
    // The Extra Op can be vectorized for execution
    vector: bool;
}

@@ -226,10 +227,14 @@ table StringVec {
    data: [string];
}

table AttentionParam {
    kv_cache: bool = true;
}

table FmhaV2Param {
    heads: int;
}

table FmhcaParam {
    heads: int;
}

@@ -408,7 +413,8 @@ union OpParameter {
    CumSum,
    GroupNorm,
    FmhaV2Param,
    FmhcaParam
    FmhcaParam,
    AttentionParam
}

table Op {
|
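The new AttentionParam table above carries a single kv_cache flag that defaults to true. A minimal sketch of building and reading it through the generated C++ API added in MNN_generated.h by this same commit:

```cpp
#include "MNN_generated.h"

// Sketch: serialize an AttentionParam and read the kv_cache flag back.
void attentionParamRoundTrip() {
    flatbuffers::FlatBufferBuilder fbb;
    auto param = MNN::CreateAttentionParam(fbb, /*kv_cache=*/false);
    fbb.Finish(param);
    auto table = flatbuffers::GetRoot<MNN::AttentionParam>(fbb.GetBufferPointer());
    bool useKVCache = table->kv_cache();  // false here; an absent field reads as true
    (void)useKVCache;
}
```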
|||
|
|
@ -53,7 +53,7 @@ asm_function MNNGemmHybridInt4FP16_sdot
|
|||
// int32_t useInt8;
|
||||
//};
|
||||
|
||||
//void MNNGemmHybridInt4_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
//void MNNGemmHybridInt4_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
|
||||
|
||||
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
|
||||
|
|
@ -79,8 +79,6 @@ lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_
|
|||
ld1 {v6.16b, v7.16b}, [x14]
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
TILE_4:
|
||||
cmp x6, #4
|
||||
blt TILE_1
|
||||
|
|
@ -120,22 +118,14 @@ LoopSz_TILE_4:
|
|||
// src : 2 x [2 x 8] : v4-5
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 2 x 4 x [4] : v16-23
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.16b, v5.16b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v9.16b, v0.16b, #4
|
||||
ushr v12.16b, v1.16b, #4
|
||||
and v8.16b, v0.16b, v14.16b
|
||||
and v13.16b, v1.16b, v14.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v12.16b, v12.16b, v15.16b
|
||||
sub v13.16b, v13.16b, v15.16b
|
||||
zip1 v0.16b, v9.16b, v8.16b
|
||||
zip2 v1.16b, v9.16b, v8.16b
|
||||
zip1 v2.16b, v12.16b, v13.16b
|
||||
zip2 v3.16b, v12.16b, v13.16b
|
||||
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v14.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v14.16b
|
||||
|
||||
mov v10.d[0], v4.d[1]
|
||||
mov v10.d[1], v4.d[0]
|
||||
mov v11.d[1], v5.d[0]
|
||||
|
|
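In the rewritten TILE_4 loop above the packed weights are loaded into v8/v9 and expanded with a single ushr/and per register; the old sub-by-8 and zip steps are gone. A scalar C++ sketch of the simplified nibble unpack (it assumes any zero-point/offset handling is now folded in elsewhere in the kernel):

```cpp
#include <cstdint>
#include <cstddef>

// Each packed byte holds two 4-bit weights: high nibble first (ushr #4),
// then low nibble (and 0x0F). No subtraction or lane re-zipping here.
static void unpackInt4(const uint8_t* packed, int8_t* out, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        out[2 * i + 0] = static_cast<int8_t>(packed[i] >> 4);    // like ushr v0.16b, v8.16b, #4
        out[2 * i + 1] = static_cast<int8_t>(packed[i] & 0x0F);  // like and v1.16b, v8.16b, v14.16b
    }
}
```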
@ -244,21 +234,13 @@ LoopSz_TILE_1:
|
|||
// src : 1 x [1 x 8] : v4
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 1 x 4 x [2] : v16-v19
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.8b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v16.16b, v0.16b, #4
|
||||
ushr v17.16b, v1.16b, #4
|
||||
and v18.16b, v0.16b, v14.16b
|
||||
and v19.16b, v1.16b, v14.16b
|
||||
sub v16.16b, v16.16b, v15.16b
|
||||
sub v18.16b, v18.16b, v15.16b
|
||||
sub v17.16b, v17.16b, v15.16b
|
||||
sub v19.16b, v19.16b, v15.16b
|
||||
zip1 v0.16b, v16.16b, v18.16b
|
||||
zip2 v1.16b, v16.16b, v18.16b
|
||||
zip1 v2.16b, v17.16b, v19.16b
|
||||
zip2 v3.16b, v17.16b, v19.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v14.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v14.16b
|
||||
|
||||
mov v29.d[0], v4.d[1]
|
||||
mov v29.d[1], v4.d[0]
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ asm_function MNNGemmHybridInt4FP16_smmla
|
|||
// int32_t useInt8;
|
||||
//};
|
||||
|
||||
//void MNNGemmHybridInt4_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
//void MNNGemmHybridInt4_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
|
||||
|
||||
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
|
||||
|
|
@ -135,22 +135,14 @@ LoopDz_TILE_8:
|
|||
LoopSz_TILE_8:
|
||||
// src : 2 x [2 x 8] : v4-5
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 2 x 4 x [4] : v16-23
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
// dst : 2 x 4 x [4] : v16-31
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v10.16b
|
||||
sub v8.16b, v8.16b, v11.16b
|
||||
sub v9.16b, v9.16b, v11.16b
|
||||
ushr v12.16b, v1.16b, #4
|
||||
and v13.16b, v1.16b, v10.16b
|
||||
sub v12.16b, v12.16b, v11.16b
|
||||
sub v13.16b, v13.16b, v11.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v12.16b, v13.16b
|
||||
zip2 v3.16b, v12.16b, v13.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
|
||||
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1
|
||||
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3
|
||||
|
|
@ -270,21 +262,13 @@ LoopSz_TILE_4:
|
|||
// src : 2 x [2 x 8] : v4-5
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 2 x 4 x [4] : v16-23
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.16b, v5.16b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v10.16b
|
||||
sub v8.16b, v8.16b, v11.16b
|
||||
sub v9.16b, v9.16b, v11.16b
|
||||
ushr v12.16b, v1.16b, #4
|
||||
and v13.16b, v1.16b, v10.16b
|
||||
sub v12.16b, v12.16b, v11.16b
|
||||
sub v13.16b, v13.16b, v11.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v12.16b, v13.16b
|
||||
zip2 v3.16b, v12.16b, v13.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
|
||||
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
|
||||
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
|
||||
|
|
@ -313,7 +297,7 @@ LoopSzEnd_TILE_4:
|
|||
trn2 v5.2d, v8.2d, v8.2d // alpha: oc 2,3,2,3
|
||||
trn1 v6.2d, v9.2d, v9.2d // alpha: oc 4,5,4,5
|
||||
trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7
|
||||
|
||||
|
||||
MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7
|
||||
MulScale_New v20, v21, v22, v23, v1, v4, v5, v6, v7
|
||||
Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2)
|
||||
|
|
@ -372,21 +356,13 @@ LoopSz_TILE_2:
|
|||
// src : 1 x [2 x 8] : v4
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 1 x 4 x [4] : v16-19
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.16b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v10.16b
|
||||
sub v8.16b, v8.16b, v11.16b
|
||||
sub v9.16b, v9.16b, v11.16b
|
||||
ushr v12.16b, v1.16b, #4
|
||||
and v13.16b, v1.16b, v10.16b
|
||||
sub v12.16b, v12.16b, v11.16b
|
||||
sub v13.16b, v13.16b, v11.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v12.16b, v13.16b
|
||||
zip2 v3.16b, v12.16b, v13.16b
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
|
||||
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
|
||||
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
|
||||
|
|
@ -409,10 +385,10 @@ LoopSzEnd_TILE_2:
|
|||
trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7
|
||||
MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7
|
||||
Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2)
|
||||
|
||||
|
||||
uzp1 v4.4s, v0.4s, v1.4s
|
||||
uzp2 v5.4s, v0.4s, v1.4s
|
||||
|
||||
|
||||
Tile2Dequant:
|
||||
ld1 {v16.8h}, [x20], #16 // zero
|
||||
ld1 {v17.8h}, [x21], #16 // bias
|
||||
|
|
@ -463,22 +439,14 @@ LoopSz_TILE_1:
|
|||
// dst : 1 x 4 x [2] : v16-v19
|
||||
prfm pldl1keep, [x25, #64] // prefetch the next block of weights
prfm pldl1keep, [x24, x15] // prefetch the next block of source data
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v4.8b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v10.16b
|
||||
sub v8.16b, v8.16b, v11.16b
|
||||
sub v9.16b, v9.16b, v11.16b
|
||||
ushr v12.16b, v1.16b, #4
|
||||
and v13.16b, v1.16b, v10.16b
|
||||
sub v12.16b, v12.16b, v11.16b
|
||||
sub v13.16b, v13.16b, v11.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v12.16b, v13.16b
|
||||
zip2 v3.16b, v12.16b, v13.16b
|
||||
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
|
||||
.inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
|
||||
.inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
|
||||
.inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
|
||||
|
|
@ -505,7 +473,7 @@ LoopSzEnd_TILE_1:
|
|||
fcvtn v17.4h, v20.4s
|
||||
fcvtn2 v17.8h, v21.4s
|
||||
Tile1Dequant:
|
||||
|
||||
|
||||
ld1 {v1.8h}, [x20], #16 // zero
|
||||
ld1 {v2.8h}, [x21], #16 // bias
|
||||
ld1 {v3.h}[0], [x22] // sums
|
||||
|
|
@ -535,4 +503,4 @@ ldp d12, d13, [sp, #(16 * 1)]
|
|||
ldp d14, d15, [sp], #(16 * 9)
|
||||
ret
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,462 @@
|
|||
//
|
||||
// CPUAttention.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2024/03/19.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
|
||||
#include <limits>
|
||||
#include "CPUAttention.hpp"
|
||||
#include "CPUBackend.hpp"
|
||||
#include "compute/CommonOptFunction.h"
|
||||
#include "core/Macro.h"
|
||||
#include "core/Concurrency.h"
|
||||
#include "core/BufferAllocator.hpp"
|
||||
#include "core/TensorUtils.hpp"
|
||||
#include "core/OpCommonUtils.hpp"
|
||||
|
||||
#if defined (__aarch64__)
|
||||
#define FLOAT16_T __fp16
|
||||
#else
|
||||
#define FLOAT16_T float
|
||||
#endif
|
||||
|
||||
|
||||
namespace MNN {
|
||||
|
||||
template <typename T>
|
||||
static void prefill_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr,
|
||||
int mMaxLength, int mNumHead, int mHeadDim, int mValueH,
|
||||
int eP, int hP, int query_e, int key_h, int seq_len, int h) {
|
||||
auto query_src = query->host<T>();
|
||||
auto key_src = key->host<T>();
|
||||
auto value_src = value->host<T>();
|
||||
auto query_dst = reinterpret_cast<T*>(query_ptr);
|
||||
auto key_dst = reinterpret_cast<T*>(key_ptr);
|
||||
auto value_dst = reinterpret_cast<T*>(value_ptr);
|
||||
// transpose query: [seq_len, num_head, head_dim] -> numhead, [seq_len/eP, head_dim, eP]
|
||||
for (int i = 0; i < query_e; i++) {
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
for (int k = 0; k < eP; k++) {
|
||||
int s = i * eP + k;
|
||||
if (s < seq_len) {
|
||||
query_dst[i * mHeadDim * eP + j * eP + k] = query_src[s * mNumHead * mHeadDim + h * mHeadDim + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// transpose key: [seq_len, num_head, head_dim] -> numhead, [seq_len/hP, head_dim, hP]
|
||||
for (int i = 0; i < key_h; i++) {
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
for (int k = 0; k < hP; k++) {
|
||||
int s = i * hP + k;
|
||||
if (s < seq_len) {
|
||||
key_dst[i * mHeadDim * hP + j * hP + k] = key_src[s * mNumHead * mHeadDim + h * mHeadDim + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// transpose value: [seq_len, num_head, head_dim] -> numhead, [head_dim/hP, seq_len, hP]
|
||||
for (int i = 0; i < mValueH; i++) {
|
||||
for (int j = 0; j < seq_len; j++) {
|
||||
for (int k = 0; k < hP; k++) {
|
||||
int hd = i * hP + k;
|
||||
if (hd < mHeadDim) {
|
||||
value_dst[i * mMaxLength * hP + j * hP + k] = value_src[j * mNumHead * mHeadDim + h * mHeadDim + hd];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
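prefill_pack above lays the query out as [seq_len/eP, head_dim, eP] blocks for the packed matmul. A small sketch of the index arithmetic it performs (the helper name is illustrative):

```cpp
#include <cstddef>

// Element (s, j) of a [seq_len, head_dim] slice is written to block s/eP,
// lane s%eP, matching query_dst[i * mHeadDim * eP + j * eP + k] above.
inline size_t packedQueryIndex(int s, int j, int head_dim, int eP) {
    int block = s / eP;
    int lane  = s % eP;
    return static_cast<size_t>(block) * head_dim * eP
         + static_cast<size_t>(j) * eP
         + static_cast<size_t>(lane);
}
```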
template <typename T>
|
||||
static void decode_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr,
|
||||
int mMaxLength, int mPastLength, int mHeadDim, int mValueH, int eP, int hP, int h) {
|
||||
auto query_src = query->host<T>();
|
||||
auto key_src = key->host<T>();
|
||||
auto value_src = value->host<T>();
|
||||
auto query_dst = reinterpret_cast<T*>(query_ptr);
|
||||
auto key_dst = reinterpret_cast<T*>(key_ptr);
|
||||
auto value_dst = reinterpret_cast<T*>(value_ptr);
|
||||
for (int i = 0; i < mHeadDim; i++) {
|
||||
query_dst[i * eP] = query_src[h * mHeadDim + i];
|
||||
}
|
||||
// transpose key: [1, num_head, head_dim] -> numhead, [kv_seq_len/hP, head_dim, hP]
|
||||
int outside_offset = UP_DIV(mPastLength, hP);
|
||||
int inside_offset = mPastLength % hP;
|
||||
for (int i = 0; i < mHeadDim; i++) {
|
||||
key_dst[(outside_offset - (inside_offset != 0)) * mHeadDim * hP + i * hP + inside_offset] = key_src[h * mHeadDim + i];
|
||||
}
|
||||
// transpose value: [1, num_head, head_dim] -> numhead, [head_dim/hP, kv_seq_len, hP]
|
||||
for (int i = 0; i < mValueH; i++) {
|
||||
for (int j = 0; j < hP; j++) {
|
||||
value_dst[i * mMaxLength * hP + mPastLength * hP + j] = value_src[h * mHeadDim + i * hP + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void prefill_unpack(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) {
|
||||
auto src_ptr = reinterpret_cast<T*>(pack_qkv);
|
||||
auto dst_ptr = reinterpret_cast<T*>(unpack_qkv);
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
int a = j / unit;
|
||||
int b = j % unit;
|
||||
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void prefill_softmax(int* mask_ptr, float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk,
|
||||
float mScale, int eP, int query_e, int seq_len, float min_val) {
|
||||
T* qk_src = reinterpret_cast<T*>(unpack_qk);
|
||||
T* qk_dst = reinterpret_cast<T*>(pack_qk);
|
||||
for (int i = 0; i < seq_len * seq_len; i++) {
|
||||
if (mask_ptr[i]) {
|
||||
mask_qk[i] = qk_src[i] * mScale;
|
||||
} else {
|
||||
mask_qk[i] = min_val;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
MNNSoftmax(softmax_qk + i * seq_len, mask_qk + i * seq_len, seq_len);
|
||||
}
|
||||
for (int i = 0; i < query_e; i++) {
|
||||
for (int j = 0; j < seq_len; j++) {
|
||||
for (int k = 0; k < eP; k++) {
|
||||
int s = i * eP + k;
|
||||
if (s < seq_len) {
|
||||
qk_dst[i * seq_len * eP + j * eP + k] = softmax_qk[s * seq_len + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void decode_softmax(float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk,
|
||||
float mScale, int eP, int kv_seq_len) {
|
||||
T* qk_src = reinterpret_cast<T*>(unpack_qk);
|
||||
T* qk_dst = reinterpret_cast<T*>(pack_qk);
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
mask_qk[i] = qk_src[i] * mScale;
|
||||
}
|
||||
// softmax
|
||||
MNNSoftmax(softmax_qk, mask_qk, kv_seq_len);
|
||||
// pack qk
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
qk_dst[i * eP] = softmax_qk[i];
|
||||
}
|
||||
}
|
||||
|
||||
void CPUAttentionImpl::allocKVCache() {
    if (!mKVCache || mPastLength < mMaxLength) {
        return;
    }
    mMaxLength = mPastLength + mExpandChunk;
    // past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP]
    mPastKey.reset(Tensor::createDevice<float>({mNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
    // past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP]
    mPastValue.reset(Tensor::createDevice<float>({mNumHead, mValueH, mMaxLength, hP}));
    backend()->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
    backend()->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
}

void CPUAttentionImpl::reallocKVCache() {
    if (!mKVCache || mPastLength < mMaxLength) {
        return;
    }
    mMaxLength = mPastLength + mExpandChunk;
    // past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP]
    auto new_key = Tensor::createDevice<float>({mNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
    // past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP]
    auto new_value = Tensor::createDevice<float>({mNumHead, mValueH, mMaxLength, hP});
    backend()->onAcquireBuffer(new_key, Backend::STATIC);
    backend()->onAcquireBuffer(new_value, Backend::STATIC);
    // copy
    for (int h = 0; h < mNumHead; h++) {
        ::memset(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes, 0, UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes);
        ::memset(new_value->host<char>() + h * mValueH * mMaxLength * hP * bytes, 0, mValueH * mMaxLength * hP * bytes);
        ::memcpy(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes,
                 mPastKey->host<char>() + h * UP_DIV(mPastLength, hP) * mHeadDim * hP * bytes,
                 UP_DIV(mPastLength, hP) * mHeadDim * hP * bytes);
        for (int i = 0; i < mValueH; i++) {
            ::memcpy(new_value->host<char>() + (h * mValueH + i) * mMaxLength * hP * bytes,
                     mPastValue->host<char>() + (h * mValueH + i) * mPastLength * hP * bytes,
                     mPastLength * hP * bytes);
        }
    }
    mPastKey.reset(new_key);
    mPastValue.reset(new_value);
    mTempQK.reset(Tensor::createDevice<float>({mThreadNum, eP + 2, mMaxLength}));
    backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC);
}
ErrorCode CPUAttentionImpl::onResize(Backend* _backend, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
mBackend = _backend;
|
||||
auto core = static_cast<CPUBackend *>(backend())->functions();
|
||||
int unit = core->pack;
|
||||
bytes = core->bytes;
|
||||
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
|
||||
|
||||
auto query = inputs[0];
|
||||
auto key = inputs[1];
|
||||
auto value = inputs[2];
|
||||
auto mask = inputs[3];
|
||||
auto shape = query->shape();
|
||||
int seq_len = shape[1];
|
||||
mThreadNum = ((CPUBackend *)backend())->threadNumber();
|
||||
mIsDecode = seq_len == 1;
|
||||
if (mPastLength == 0 || seq_len > 1) {
|
||||
mPastLength = seq_len;
|
||||
}
|
||||
mNumHead = shape[2];
|
||||
mHeadDim = shape[3];
|
||||
mScale = 1.0 / sqrt(mHeadDim);
|
||||
mValueH = UP_DIV(mHeadDim, hP);
|
||||
int query_e = UP_DIV(seq_len, eP);
|
||||
int key_h = UP_DIV(seq_len, hP);
|
||||
// mPastLength = 10;
|
||||
// alloc kv cache
|
||||
allocKVCache();
|
||||
|
||||
int tileCount = UP_DIV(mNumHead, mThreadNum);
|
||||
|
||||
// temp_query
|
||||
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, query_e, mHeadDim, eP}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
|
||||
if (mIsDecode) {
|
||||
mTempQK.reset(Tensor::createDevice<float>({mThreadNum, eP + 2, mMaxLength}));
|
||||
backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
|
||||
} else {
|
||||
mTempQK.reset(Tensor::createDevice<float>({mThreadNum, 4, seq_len, seq_len}));
|
||||
backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
|
||||
}
|
||||
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode CPUAttentionImpl::onExecute(Backend* _backend, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
auto core = static_cast<CPUBackend *>(backend())->functions();
|
||||
int unit = core->pack;
|
||||
bytes = core->bytes;
|
||||
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
|
||||
mBackend = _backend;
|
||||
auto matmulUnit = core->MNNPackedMatMul;
|
||||
auto matmulRemain = core->MNNPackedMatMulRemain;
|
||||
|
||||
auto query = inputs[0];
|
||||
auto key = inputs[1];
|
||||
auto value = inputs[2];
|
||||
auto mask = inputs[3];
|
||||
auto shape = query->shape();
|
||||
int seq_len = shape[1];
|
||||
mThreadNum = ((CPUBackend *)backend())->threadNumber();
|
||||
mIsDecode = seq_len == 1;
|
||||
if (mPastLength == 0 || seq_len > 1) {
|
||||
mPastLength = seq_len;
|
||||
}
|
||||
mNumHead = shape[2];
|
||||
mHeadDim = shape[3];
|
||||
mScale = 1.0 / sqrt(mHeadDim);
|
||||
mValueH = UP_DIV(mHeadDim, hP);
|
||||
int query_e = UP_DIV(seq_len, eP);
|
||||
int key_h = UP_DIV(seq_len, hP);
|
||||
// mPastLength = 10;
|
||||
|
||||
int tileCount = UP_DIV(mNumHead, mThreadNum);
|
||||
|
||||
// try calloc kv cache
|
||||
mPrefill = [=](int tId){
|
||||
auto pack_q = mPackQ->host<char>() + tId * query_e * mHeadDim * eP * bytes;
|
||||
auto pack_qk = mTempQK->host<char>() + tId * 4 * seq_len * seq_len * bytes;
|
||||
auto unpack_qk = pack_qk + seq_len * seq_len * 2 * bytes;
|
||||
auto mask_qk = reinterpret_cast<float*>(pack_qk);
|
||||
auto softmax_qk = reinterpret_cast<float*>(unpack_qk);
|
||||
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
|
||||
|
||||
int head_index = tId * tileCount;
|
||||
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
|
||||
// pack for matmul
|
||||
auto key_dst = mPastKey->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes;
|
||||
auto value_dst = mPastValue->host<char>() + h * mValueH * mMaxLength * hP * bytes;
|
||||
if (bytes == 2) {
|
||||
prefill_pack<int16_t>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mNumHead, mHeadDim, mValueH, eP, hP, query_e, key_h, seq_len, h);
|
||||
} else {
|
||||
prefill_pack<float>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mNumHead, mHeadDim, mValueH, eP, hP, query_e, key_h, seq_len, h);
|
||||
}
|
||||
// query @ key
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = mHeadDim;
|
||||
parameters[2] = seq_len;
|
||||
parameters[3] = seq_len * unit * bytes;
|
||||
parameters[4] = 0;
|
||||
parameters[5] = 0;
|
||||
matmulUnit((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_dst, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
{
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = mHeadDim;
|
||||
parameters[2] = seq_len;
|
||||
parameters[3] = seq_len * unit * bytes;
|
||||
parameters[4] = 0;
|
||||
parameters[5] = 0;
|
||||
matmulRemain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
int area_offset[1] {seq_len};
|
||||
core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, seq_len, area_offset);
|
||||
// div scale and mask
|
||||
auto mask_ptr = mask->host<int>();
|
||||
if (bytes == 2) {
|
||||
prefill_softmax<FLOAT16_T>(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, query_e, seq_len, -65504.0);
|
||||
} else {
|
||||
prefill_softmax<float>(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, query_e, seq_len, std::numeric_limits<float>::lowest());
|
||||
}
|
||||
// qk @ v
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = seq_len;
|
||||
parameters[2] = mHeadDim;
|
||||
parameters[3] = seq_len * unit * bytes;
|
||||
parameters[4] = 0;
|
||||
parameters[5] = (mMaxLength - seq_len) * hP * bytes;
|
||||
matmulUnit((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(pack_qk + (i * seq_len * eP) * bytes), (float*)value_dst, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
{
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = seq_len;
|
||||
parameters[2] = mHeadDim;
|
||||
parameters[3] = seq_len * unit * bytes;
|
||||
parameters[4] = 0;
|
||||
parameters[5] = (mMaxLength - seq_len) * hP * bytes;
|
||||
matmulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(pack_qk + (loop_e * seq_len * eP) * bytes), (float*)value_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
// transpose: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
|
||||
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
|
||||
if (bytes == 2) {
|
||||
prefill_unpack<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
|
||||
} else {
|
||||
prefill_unpack<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
mDecode = [=](int tId) {
|
||||
int kv_seq_len = mPastLength + 1;
|
||||
auto pack_q = mPackQ->host<char>() + tId * mHeadDim * eP * bytes;
|
||||
auto pack_qk = mTempQK->host<char>() + tId * (eP + 2) * kv_seq_len * bytes;
|
||||
auto unpack_qk = pack_qk + kv_seq_len * eP * bytes;
|
||||
auto mask_qk = reinterpret_cast<float*>(pack_qk);
|
||||
auto softmax_qk = reinterpret_cast<float*>(unpack_qk);
|
||||
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * unit * bytes;
|
||||
|
||||
int head_index = tId * tileCount;
|
||||
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
|
||||
auto key_dst = mPastKey->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes;
|
||||
auto value_dst = mPastValue->host<char>() + h * mValueH * mMaxLength * hP * bytes;
|
||||
// pack for matmul
|
||||
if (bytes == 2) {
|
||||
decode_pack<int16_t>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mPastLength, mHeadDim, mValueH, eP, hP, h);
|
||||
} else {
|
||||
decode_pack<float>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mPastLength, mHeadDim, mValueH, eP, hP, h);
|
||||
}
|
||||
// query @ key: [1, head_dim] @ [head_dim, kv_seq_len] -> [1, kv_seq_len]
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = mHeadDim;
|
||||
parameters[2] = kv_seq_len;
|
||||
parameters[3] = seq_len * unit * bytes;
|
||||
parameters[4] = 0;
|
||||
parameters[5] = 0;
|
||||
matmulRemain((float*)pack_qk, (float*)pack_q, (float*)key_dst, seq_len, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
int area_offset[1] {seq_len};
|
||||
core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, kv_seq_len, area_offset);
|
||||
if (bytes == 2) {
|
||||
decode_softmax<FLOAT16_T>(mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, kv_seq_len);
|
||||
} else {
|
||||
decode_softmax<float>(mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, kv_seq_len);
|
||||
}
|
||||
// qk @ v: [1, kv_seq_len] @ [kv_seq_len, head_dim] -> [1, head_dim]
|
||||
{
|
||||
size_t shapeParameters[6];
|
||||
size_t* parameters = shapeParameters;
|
||||
parameters[0] = eP * bytes;
|
||||
parameters[1] = kv_seq_len;
|
||||
parameters[2] = mHeadDim;
|
||||
parameters[3] = 1 * unit * bytes;
parameters[4] = 0;
parameters[5] = (mMaxLength - kv_seq_len) * hP * bytes;
|
||||
matmulRemain((float*)pack_qkv, (float*)pack_qk, (float*)value_dst, 1, parameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
// transpose: [head_dim/unit, 1, unit] -> [1, num_head, head_dim]
|
||||
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
|
||||
core->MNNUnpackCUnitTranspose((float*)dst_ptr, (float*)pack_qkv, 1, mHeadDim, area_offset);
|
||||
}
|
||||
};
|
||||
mFunction = mIsDecode ? mDecode : mPrefill;
|
||||
reallocKVCache();
|
||||
// compute
|
||||
MNN_CONCURRENCY_BEGIN(tId, mThreadNum) {
|
||||
mFunction((int)tId);
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
mPastLength += mIsDecode;
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
CPUAttention::CPUAttention(Backend* backend, bool kv_cache) : Execution(backend) {
    mImpl.reset(new CPUAttentionImpl(backend, kv_cache));
}

CPUAttention::CPUAttention(std::shared_ptr<CPUAttentionImpl> impl, Backend *backend) : Execution(backend), mImpl(impl) {}

ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    return mImpl->onResize(backend(), inputs, outputs);
}

ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    return mImpl->onExecute(backend(), inputs, outputs);
}

bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (nullptr == dst) {
        return true;
    }
    // share the same impl (and therefore the same KV cache) with the cloned execution
    *dst = new CPUAttention(mImpl, bn);
    return true;
}

class CPUAttentionCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto param = op->main_as_AttentionParam();
        return new CPUAttention(backend, param->kv_cache());
    }
};

REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention);

} // namespace MNN

#endif
@ -0,0 +1,58 @@
|
|||
//
//  CPUAttention.hpp
//  MNN
//
//  Created by MNN on 2024/03/19.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE

#ifndef CPUATTENTION_HPP
#define CPUATTENTION_HPP

#include <functional>
#include "core/Execution.hpp"

namespace MNN {

class CPUAttentionImpl {
public:
    CPUAttentionImpl(Backend *backend, bool kv_cache) : mBackend(backend), mKVCache(kv_cache) {}
    ~CPUAttentionImpl() = default;
    ErrorCode onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
    ErrorCode onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
private:
    void allocKVCache();
    void reallocKVCache();
    Backend* backend() { return mBackend; }
private:
    Backend* mBackend;
    bool mKVCache;
    float mScale;
    const int mExpandChunk = 64;
    int mThreadNum = 1;
    bool mIsDecode = false;
    int mPastLength = 0, mMaxLength = 0;
    std::shared_ptr<Tensor> mPastKey, mPastValue, mTempQK;
    std::shared_ptr<Tensor> mPackQ, mPackQKV;
    int mNumHead = 0, mHeadDim = 0, mValueH = 0;
    int eP, lP, hP, bytes;
    std::function<void(int)> mFunction, mPrefill, mDecode;
};

class CPUAttention : public Execution {
public:
    CPUAttention(Backend *backend, bool kv_cache);
    CPUAttention(std::shared_ptr<CPUAttentionImpl> impl, Backend *backend);
    virtual ~CPUAttention() = default;
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
    std::shared_ptr<CPUAttentionImpl> mImpl;
};
} // namespace MNN

#endif // CPUATTENTION_HPP
#endif
@ -93,7 +93,7 @@ public:
|
|||
virtual void* onMapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* srcTensor) override;
|
||||
|
||||
virtual bool onUnmapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* dstTensor, void* mapPtr) override;
|
||||
|
||||
|
||||
virtual void onResizeBegin() override;
|
||||
virtual ErrorCode onResizeEnd() override;
|
||||
|
||||
|
|
@ -194,6 +194,12 @@ private:
|
|||
CPUBackend::addCreator(opType, &_temp); \
|
||||
}
|
||||
|
||||
#define REGISTER_CPU_OP_CREATOR_TRANSFORMER(name, opType) \
|
||||
void ___##name##__##opType##__() { \
|
||||
static name _temp;\
|
||||
CPUBackend::addCreator(opType, &_temp); \
|
||||
}
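For the attention op this macro expands to roughly the following, which matches the ___CPUAttentionCreator__OpType_Attention__ symbol declared and called in the CPU op registry further down (a sketch of the expansion, not additional code in this change):

void ___CPUAttentionCreator__OpType_Attention__() {
    static CPUAttentionCreator _temp;
    CPUBackend::addCreator(OpType_Attention, &_temp);
}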
|
||||
|
||||
} // namespace MNN
|
||||
|
||||
#endif /* CPUBackend_hpp */
|
||||
|
|
|
|||
|
|
@ -105,6 +105,10 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
|
|||
mInnerSize *= inputs.at(0)->length(i);
|
||||
}
|
||||
mInnerSize /= mResource->mGroup;
|
||||
if (mResource->mIniGammaBeta) {
|
||||
MNN_ASSERT(mResource->mGamma->size() == mInnerSize * sizeof(float));
|
||||
}
|
||||
|
||||
return NO_ERROR;
|
||||
}
|
||||
for (int i = 0; i < rank - mResource->mAxis; ++i) {
|
||||
|
|
@ -113,6 +117,9 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
|
|||
for (int i = rank - mResource->mAxis; i < rank; ++i) {
|
||||
mInnerSize *= inputs.at(0)->length(i);
|
||||
}
|
||||
if (mResource->mIniGammaBeta) {
|
||||
MNN_ASSERT(mResource->mGamma->size() == mInnerSize * sizeof(float));
|
||||
}
|
||||
if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) {
|
||||
mInpZero.resize(1);
|
||||
mOutZero.resize(1);
|
||||
|
|
|
|||
|
|
@ -75,6 +75,9 @@ extern void ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
|
|||
extern void ___CPURasterDiffCreator__OpType_RasterDiff__();
|
||||
extern void ___CPUTextureCreator__OpType_Texture__();
|
||||
#endif
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
extern void ___CPUAttentionCreator__OpType_Attention__();
|
||||
#endif
|
||||
void registerCPUOps() {
|
||||
___CPUCropAndResizeCreator__OpType_CropAndResize__();
|
||||
___CPUArgMaxCreator__OpType_ArgMax__();
|
||||
|
|
@ -150,5 +153,8 @@ ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
|
|||
___CPURasterDiffCreator__OpType_RasterDiff__();
|
||||
___CPUTextureCreator__OpType_Texture__();
|
||||
#endif
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
___CPUAttentionCreator__OpType_Attention__();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#include "./FunctionSummary.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
|
||||
extern "C" {
|
||||
void MNNTranspose32Bit4x4(int32_t* dstO, const int32_t* srcO, int32_t* dim);
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32
|
|||
// int32_t useInt8;
|
||||
//};
|
||||
|
||||
//void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
//void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
|
||||
|
||||
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
|
||||
|
|
@ -120,8 +120,6 @@ LoopSz_TILE_4:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
|
||||
Unit_TILE_4:
|
||||
|
|
@ -244,8 +242,6 @@ LoopSz_TILE_1:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
|
||||
Unit_TILE_1:
|
||||
|
|
@ -261,7 +257,7 @@ LoopSz_TILE_1:
|
|||
smlal v11.4s, v5.4h, v29.4h
|
||||
smlal v12.4s, v5.4h, v30.4h
|
||||
smlal v13.4s, v5.4h, v31.4h
|
||||
|
||||
|
||||
//.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
|
||||
|
||||
subs x26, x26, #1
|
||||
|
|
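The ushr / and / sub / zip1 sequence that repeats in these int4 GEMM kernels expands two packed int4 weights per byte into signed int8 values before the multiply-accumulate. A scalar C++ sketch of that expansion (the NEON code processes 16 bytes at a time; the - 8 recenters the unsigned nibbles to [-8, 7]):

#include <cstdint>

// Unpack n bytes of int4 weights (two per byte) into 2*n signed int8 values.
static void unpackInt4ToInt8(const uint8_t* src, int8_t* dst, int n) {
    for (int i = 0; i < n; ++i) {
        dst[2 * i + 0] = (int8_t)((src[i] >> 4) - 8);   // high nibble: ushr #4 then sub #8
        dst[2 * i + 1] = (int8_t)((src[i] & 0x0F) - 8); // low nibble: and #15 then sub #8
    }
}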
|
|||
|
|
@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32_sdot
|
|||
// int32_t useInt8;
|
||||
//};
|
||||
|
||||
//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
|
||||
|
||||
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
|
||||
|
|
@ -99,11 +99,9 @@ LoopDz_TILE_12:
|
|||
movi v25.4s, #0
|
||||
movi v26.4s, #0
|
||||
movi v27.4s, #0
|
||||
|
||||
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
LoopSz_TILE_12:
|
||||
// src : 4(batch) x [1 x 4] : v4
|
||||
// weight : 4(oc) x [1 x 4] : v0
|
||||
|
|
@ -113,8 +111,6 @@ LoopSz_TILE_12:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
|
||||
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
|
||||
|
|
@ -203,11 +199,9 @@ LoopDz_TILE_8:
|
|||
movi v21.4s, #0
|
||||
movi v22.4s, #0
|
||||
movi v23.4s, #0
|
||||
|
||||
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
LoopSz_TILE_8:
|
||||
// src : 4(batch) x [1 x 4] : v4
|
||||
// weight : 4(oc) x [1 x 4] : v0
|
||||
|
|
@ -217,8 +211,6 @@ LoopSz_TILE_8:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
|
||||
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
|
||||
|
|
@ -294,8 +286,6 @@ LoopDz_TILE_4:
|
|||
dup v19.4s, wzr
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
LoopSz_TILE_4:
|
||||
// src : 4(batch) x [1 x 4] : v4
|
||||
// weight : 4(oc) x [1 x 4] : v0
|
||||
|
|
@ -305,8 +295,6 @@ LoopSz_TILE_4:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
|
||||
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
|
||||
|
|
@ -368,8 +356,6 @@ LoopDz_TILE_1:
|
|||
dup v16.4s, wzr
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
LoopSz_TILE_1:
|
||||
// src : 1(batch) x [1 x 4] : v4
|
||||
// weight : 4(oc) x [1 x 4] : v0
|
||||
|
|
@ -379,10 +365,8 @@ LoopSz_TILE_1:
|
|||
// int4->int8
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
|
||||
|
||||
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
|
||||
|
||||
subs x26, x26, #1
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32_smmla
|
|||
// int32_t useInt8;
|
||||
//};
|
||||
|
||||
//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
|
||||
|
||||
|
||||
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
|
||||
|
|
@ -106,28 +106,18 @@ LoopDz_TILE_8:
|
|||
|
||||
// mask
|
||||
movi v10.16b, #15
|
||||
// offset
|
||||
movi v11.16b, #8
|
||||
LoopSz_TILE_8:
|
||||
// src : 2 x [2 x 8] : v4-5
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 2 x 4 x [4] : v16-23
|
||||
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
ld1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x24], x15 // src
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v4.16b, v0.16b, #4
|
||||
and v5.16b, v0.16b, v10.16b
|
||||
sub v4.16b, v4.16b, v11.16b
|
||||
sub v5.16b, v5.16b, v11.16b
|
||||
ushr v8.16b, v1.16b, #4
|
||||
and v9.16b, v1.16b, v10.16b
|
||||
sub v8.16b, v8.16b, v11.16b
|
||||
sub v9.16b, v9.16b, v11.16b
|
||||
zip1 v0.16b, v4.16b, v5.16b
|
||||
zip2 v1.16b, v4.16b, v5.16b
|
||||
zip1 v2.16b, v8.16b, v9.16b
|
||||
zip2 v3.16b, v8.16b, v9.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
|
||||
.inst 0x4e80a590 // smmla v16.4s, v12.16b, v0.16b
|
||||
.inst 0x4e81a591 // smmla v17.4s, v12.16b, v1.16b
|
||||
|
|
@ -247,27 +237,17 @@ LoopDz_TILE_4:
|
|||
dup v23.4s, wzr
|
||||
// mask
|
||||
movi v10.16b, #15
|
||||
// offset
|
||||
movi v11.16b, #8
|
||||
LoopSz_TILE_4:
|
||||
// src : 2 x [2 x 8] : v4-5
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 2 x 4 x [4] : v16-23
|
||||
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v4.16b, v0.16b, #4
|
||||
and v5.16b, v0.16b, v10.16b
|
||||
sub v4.16b, v4.16b, v11.16b
|
||||
sub v5.16b, v5.16b, v11.16b
|
||||
ushr v6.16b, v1.16b, #4
|
||||
and v7.16b, v1.16b, v10.16b
|
||||
sub v6.16b, v6.16b, v11.16b
|
||||
sub v7.16b, v7.16b, v11.16b
|
||||
zip1 v0.16b, v4.16b, v5.16b
|
||||
zip2 v1.16b, v4.16b, v5.16b
|
||||
zip1 v2.16b, v6.16b, v7.16b
|
||||
zip2 v3.16b, v6.16b, v7.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v10.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v10.16b
|
||||
ld1 {v4.16b, v5.16b}, [x24], x15 // src
|
||||
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
|
||||
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
|
||||
|
|
@ -349,27 +329,17 @@ LoopDz_TILE_2:
|
|||
dup v19.4s, wzr
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
LoopSz_TILE_2:
|
||||
// src : 1 x [2 x 8] : v4
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 1 x 4 x [4] : v16-19
|
||||
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
ushr v10.16b, v1.16b, #4
|
||||
and v11.16b, v1.16b, v14.16b
|
||||
sub v10.16b, v10.16b, v15.16b
|
||||
sub v11.16b, v11.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v10.16b, v11.16b
|
||||
zip2 v3.16b, v10.16b, v11.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v14.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v14.16b
|
||||
ld1 {v4.16b}, [x24], x15 // src
|
||||
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
|
||||
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
|
||||
|
|
@ -438,28 +408,18 @@ LoopDz_TILE_1:
|
|||
dup v19.4s, wzr
|
||||
// mask
|
||||
movi v14.16b, #15
|
||||
// offset
|
||||
movi v15.16b, #8
|
||||
|
||||
LoopSz_TILE_1:
|
||||
// src : 1 x [1 x 8] : v4
|
||||
// weight : 4 x [2 x 8] : v0-3
|
||||
// dst : 1 x 4 x [2] : v16-v19
|
||||
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
|
||||
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
|
||||
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
|
||||
// int4 to int8: v0, v1, v2, v3
|
||||
ushr v8.16b, v0.16b, #4
|
||||
and v9.16b, v0.16b, v14.16b
|
||||
sub v8.16b, v8.16b, v15.16b
|
||||
sub v9.16b, v9.16b, v15.16b
|
||||
ushr v10.16b, v1.16b, #4
|
||||
and v11.16b, v1.16b, v14.16b
|
||||
sub v10.16b, v10.16b, v15.16b
|
||||
sub v11.16b, v11.16b, v15.16b
|
||||
zip1 v0.16b, v8.16b, v9.16b
|
||||
zip2 v1.16b, v8.16b, v9.16b
|
||||
zip1 v2.16b, v10.16b, v11.16b
|
||||
zip2 v3.16b, v10.16b, v11.16b
|
||||
ushr v0.16b, v8.16b, #4
|
||||
and v1.16b, v8.16b, v14.16b
|
||||
ushr v2.16b, v9.16b, #4
|
||||
and v3.16b, v9.16b, v14.16b
|
||||
ld1 {v4.8b}, [x24], x15 // src
|
||||
.inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
|
||||
.inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@
|
|||
#include "math/Vec.hpp"
|
||||
#include <vector>
|
||||
#include "../CPURuntime.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
// TODO: Find better way to optimize it
|
||||
#include "../CPUBinary.hpp"
|
||||
#include "../CPUUnary.hpp"
|
||||
|
|
@ -376,7 +376,7 @@ void MNN1BitCopyFast (uint8_t* dstO, const uint8_t* srcO, int size, int stride,
|
|||
#ifdef MNN_USE_SSE
|
||||
std::vector<uint8_t> arr(16, val);
|
||||
auto val16 = _mm_loadu_ps((float*)arr.data());
|
||||
|
||||
|
||||
for (; cnt >= 16; cnt-=16) {
|
||||
_mm_storeu_ps((float*)dstO, val16);
|
||||
dstO += 16;
|
||||
|
|
@ -427,7 +427,7 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
|
|||
if (size >= 8) {
|
||||
auto sum4_1 = _mm_set_ps1(0.f);
|
||||
auto sum4_2 = _mm_set_ps1(0.f);
|
||||
|
||||
|
||||
for (; i < size8; i += 8) {
|
||||
auto v4 = _mm_loadu_ps(src);
|
||||
auto u4 = _mm_loadu_ps(src + 4);
|
||||
|
|
@ -435,7 +435,7 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
|
|||
sum4_2 = _mm_add_ps(sum4_2, u4);
|
||||
src += 8;
|
||||
}
|
||||
|
||||
|
||||
sum4_1 = _mm_add_ps(sum4_1, sum4_2);
|
||||
_mm_storeu_ps(tmp, sum4_1);
|
||||
sum += (tmp[0] + tmp[1] + tmp[2] + tmp[3]);
|
||||
|
|
@ -823,7 +823,7 @@ void MNNGemmHybridInt8FP32(float* C, const int8_t* A, const int8_t* B, size_t sr
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// int32->float
|
||||
for (int cn = 0; cn < pack; ++cn) {
|
||||
float val = (float)tmp[cn] * scale[0];
|
||||
|
|
@ -873,7 +873,7 @@ void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t sr
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// int32->float
|
||||
for (int cn = 0; cn < pack; ++cn) {
|
||||
float val = (float)tmp[cn] * scale[0];
|
||||
|
|
@ -3333,7 +3333,7 @@ void MNNCoreFunctionInit() {
|
|||
gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg<float, Vec4, 4>);
|
||||
// Set min value as 1 << 24
|
||||
gCoreFunction->MNNPoolingMax = (decltype(gCoreFunction->MNNPoolingMax))(poolingMax<float, Vec4, 4, -16777216>);
|
||||
|
||||
|
||||
gCoreFunction->MNNPoolingMaxWithRedice = (decltype(gCoreFunction->MNNPoolingMaxWithRedice))(poolingMaxWithRedice<float, -16777216>);
|
||||
// ImageProcess Functions
|
||||
gCoreFunction->MNNRGBAToBGRA = MNNRGBAToBGRA;
|
||||
|
|
@ -3353,7 +3353,7 @@ void MNNCoreFunctionInit() {
|
|||
gCoreFunction->MNN4BitcopyFast = MNN4BitcopyFast;
|
||||
gCoreFunction->MNN2BitcopyFast = MNN2BitcopyFast;
|
||||
gCoreFunction->MNN1BitcopyFast = MNN1BitCopyFast;
|
||||
|
||||
|
||||
gCoreFunction->MNNAccumulateSequenceNumber = MNNAccumulateSequenceNumber;
|
||||
|
||||
cpuinfo_arm_isa gCPUInfo;
|
||||
|
|
|
|||
|
|
@ -33,19 +33,7 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
|
|||
resource->lU = lU;
|
||||
resource->hP = hP;
|
||||
resource->lP = lP;
|
||||
// Reorder weight
|
||||
auto dstWInt8 = resource->mWeight->host<int8_t>();
|
||||
auto srcWInt8 = int8Info->weight.get();
|
||||
// oc, ic -> oc/hP, ic/lP, hP, lP
|
||||
for (int i = 0; i < hU; i++) {
|
||||
for (int j = 0; j < lU; j++) {
|
||||
for (int k = 0; k < hP; k++) {
|
||||
for (int l = 0; l < lP; l++) {
|
||||
dstWInt8[i * srcChannel * hP + j * hP * lP + k * lP + l] = srcWInt8[(i * hP + k) * srcChannel + (j * lP + l)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save scale bias
|
||||
resource->mDequantize.mScaleBias.reset(MNN::Tensor::createDevice<float>({hU * hP * 2}));
|
||||
res = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC);
|
||||
|
|
@ -56,6 +44,12 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
|
|||
auto biasPtr = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(alphaPtr) + hU * hP * bytes);
|
||||
::memset(alphaPtr, 0, 2 * hU * hP * bytes);
|
||||
int h = int8Info->alpha.size();
|
||||
if (int8Info->canUseInt4) {
|
||||
// int4 to uint4, -8 offset merge to bias
|
||||
for (int i = 0; i < h/2; ++i) {
|
||||
int8Info->alpha.get()[2 * i] -= 8 * int8Info->alpha.get()[2 * i + 1];
|
||||
}
|
||||
}
|
||||
if (bytes == 2) {
|
||||
auto core = static_cast<CPUBackend*>(resource->backend)->functions();
|
||||
if (int8Info->asymmetric) {
|
||||
|
|
@ -86,21 +80,53 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
|
|||
MNN_ASSERT(weightLength % 2 == 0);
|
||||
weightLength = UP_DIV(weightLength, 2);
|
||||
resource->mDequantize.bits = 4;
|
||||
std::shared_ptr<MNN::Tensor> weightLow(Tensor::createDevice<uint8_t>(
|
||||
{weightLength}));
|
||||
auto res = resource->backend->onAcquireBuffer(weightLow.get(), Backend::STATIC);
|
||||
if (!res) {
|
||||
return false;
|
||||
|
||||
auto srcPtr = int8Info->weight.get();
|
||||
auto dstPtr = resource->mWeight->host<uint8_t>();
|
||||
// oc, ic -> oc/hP, ic/lP, hP, lP
|
||||
if (hP == 8 && lP == 8) {
|
||||
for (int i = 0; i < hU; i++) {
|
||||
for (int j = 0; j < lU; j++) {
|
||||
for (int k = 0; k < 2; k++) {
|
||||
for (int n = 0; n < 16; n++) {
|
||||
int hp_idx = n / 8;
|
||||
int lp_idx = n % 8;
|
||||
int s0 = srcPtr[(i * hP + k * 4 + hp_idx) * srcChannel + (j * lP + lp_idx)];
|
||||
int s1 = srcPtr[(i * hP + k * 4 + hp_idx + 2) * srcChannel + (j * lP + lp_idx)];
|
||||
int d = (s0 + 8) * 16 + (s1 + 8);
|
||||
dstPtr[(i * srcChannel * hP + j * hP * lP + k * 32) / 2 + n] = (uint8_t)d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < hU; i++) {
|
||||
for (int j = 0; j < lU; j++) {
|
||||
for (int k = 0; k < hP; k++) {
|
||||
for (int l = 0; l < lP; l+=2) {
|
||||
int s0 = srcPtr[(i * hP + k) * srcChannel + (j * lP + l)];
|
||||
int s1 = srcPtr[(i * hP + k) * srcChannel + (j * lP + l + 1)];
|
||||
int d = (s0 + 8) * 16 + (s1 + 8);
|
||||
dstPtr[(i * srcChannel * hP + j * hP * lP + k * lP + l) / 2] = d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto srcPtr = resource->mWeight->host<int8_t>();
|
||||
auto dstPtr = weightLow->host<uint8_t>();
|
||||
for (int i=0; i < weightLength; ++i) {
|
||||
int s0 = srcPtr[2 * i + 0];
|
||||
int s1 = srcPtr[2 * i + 1];
|
||||
int d = (s0 + 8) * 16 + (s1 + 8);
|
||||
dstPtr[i] = d;
|
||||
} else {
|
||||
// Reorder weight for int8
|
||||
auto dstWInt8 = resource->mWeight->host<int8_t>();
|
||||
auto srcWInt8 = int8Info->weight.get();
|
||||
// oc, ic -> oc/hP, ic/lP, hP, lP
|
||||
for (int i = 0; i < hU; i++) {
|
||||
for (int j = 0; j < lU; j++) {
|
||||
for (int k = 0; k < hP; k++) {
|
||||
for (int l = 0; l < lP; l++) {
|
||||
dstWInt8[i * srcChannel * hP + j * hP * lP + k * lP + l] = srcWInt8[(i * hP + k) * srcChannel + (j * lP + l)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resource->mWeight = weightLow;
|
||||
}
|
||||
return true;
|
||||
}
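The weight repack above stores two signed int4 values per byte as d = (s0 + 8) * 16 + (s1 + 8), i.e. each value is biased into an unsigned nibble with s0 in the high half. A short sketch of the pack/unpack round trip this relies on (values are assumed to already lie in [-8, 7]; the helper names are illustrative):

#include <cassert>
#include <cstdint>

static uint8_t packInt4Pair(int s0, int s1) {
    // bias each value by +8 into [0, 15], keep s0 in the high nibble
    return (uint8_t)((s0 + 8) * 16 + (s1 + 8));
}

static void unpackInt4Pair(uint8_t d, int& s0, int& s1) {
    s0 = (d >> 4) - 8;
    s1 = (d & 0x0F) - 8;
}

// Round-trip check over the full int4 range.
static void verifyInt4RoundTrip() {
    for (int s0 = -8; s0 <= 7; ++s0) {
        for (int s1 = -8; s1 <= 7; ++s1) {
            int a, b;
            unpackInt4Pair(packInt4Pair(s0, s1), a, b);
            assert(a == s0 && b == s1);
        }
    }
}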
|
||||
|
|
@ -129,10 +155,10 @@ ConvolutionHybrid::ConvolutionHybrid(const Convolution2DCommon *common, Backend
|
|||
hPack = unit;
|
||||
lPack = unit;
|
||||
// [oc, ic] => [oc/unit, ic/src_unit, unit, src_unit]
|
||||
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
|
||||
hPack = 8;
|
||||
lPack = 8;
|
||||
}
|
||||
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
|
||||
hPack = 8;
|
||||
lPack = 8;
|
||||
}
|
||||
auto hU = UP_DIV(outputCount, hPack);
|
||||
auto lU = UP_DIV(inputCount, lPack);
|
||||
ConvolutionHybrid::initQuantizeResource(quantInfo, mResource, hU, hPack, lU, lPack, outputCount, (int)originWeightSize / (int)biasSize, common->kernelX() * common->kernelY(), core->bytes);
|
||||
|
|
@ -237,8 +263,8 @@ ErrorCode ConvolutionHybrid::onResize(const std::vector<Tensor *> &inputs, const
|
|||
int iTileCount = UP_DIV(iC4, threadNumber);
|
||||
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
|
||||
ANeedToPack8 = true;
|
||||
}
|
||||
int8_t order[32] = {0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31, 8, 9, 10, 11, 4, 5, 6, 7, 24, 25, 26, 27, 20, 21, 22, 23};
|
||||
}
|
||||
int8_t order[32] = {0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31, 8, 9, 10, 11, 4, 5, 6, 7, 24, 25, 26, 27, 20, 21, 22, 23};
|
||||
allocDynamicQuantInfo(threadNumber, batch, ic, oc, bytes);
|
||||
mDynamicQuant = [=]() {
|
||||
auto maxPtr = mQuantInfo.quant_info.host<uint8_t>();
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/WingoradGenerater.hpp"
|
||||
#include <MNN/AutoTime.hpp>
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
#ifdef MNN_USE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/WingoradGenerater.hpp"
|
||||
#include <MNN/AutoTime.hpp>
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
|
||||
constexpr int FULSE_THRESHHOLD_NUMERATOR = 10;
|
||||
constexpr int FULSE_THRESHHOLD_DENOMINATOR = 10;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/Vec.hpp"
|
||||
#include "core/BufferAllocator.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
|
||||
using Vec4 = MNN::Math::Vec<float, 4>;
|
||||
namespace MNN {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/WingoradGenerater.hpp"
|
||||
#include <MNN/AutoTime.hpp>
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
|
||||
|
||||
//#define MNN_WINOGRAD_PRINT_REDUCE_RATE
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/Vec.hpp"
|
||||
#include "core/BufferAllocator.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
#define PARAMETERSIZE 6
|
||||
|
||||
using Vec4 = MNN::Math::Vec<float, 4>;
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include <cstring> // for memset
|
||||
#include "Int8FunctionsOpt.h"
|
||||
#include "core/Macro.h"
|
||||
#include "common/CommonCompute.hpp"
|
||||
#include "core/CommonCompute.hpp"
|
||||
#include "CommonOptFunction.h"
|
||||
#include "math/Vec.hpp"
|
||||
|
||||
|
|
@ -1537,7 +1537,7 @@ static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src,
|
|||
}
|
||||
|
||||
for (int i = 0; i < pack; ++i) {
|
||||
|
||||
|
||||
float val = (dstInt32[i] + bias_z[i]) * scale_z[i];
|
||||
int valOut = roundf(val) + offset;
|
||||
if (valOut > parameters->maxValue + offset) {
|
||||
|
|
@ -1615,7 +1615,7 @@ void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWi
|
|||
for (int y = 0; y < kernely; ++y) {
|
||||
for (int x = 0; x < kernelx; ++x) {
|
||||
const int8_t* inputPtr = srcPtr + pack* (x + inputWidth* y);
|
||||
for (int idx = 0; idx < pack; ++idx) {
|
||||
for (int idx = 0; idx < pack; ++idx) {
|
||||
results[idx] = std::max(results[idx], *(inputPtr + idx));
|
||||
}
|
||||
}
|
||||
|
|
@ -2125,7 +2125,7 @@ void MNNCoreInt8FunctionInit() {
|
|||
// pooling
|
||||
gCoreFunc->MNNAvgPoolInt8 = MNNAvgPoolInt8;
|
||||
gCoreFunc->MNNMaxPoolInt8 = MNNMaxPoolInt8;
|
||||
|
||||
|
||||
// Norm
|
||||
gCoreFunc->MNNNormInt8 = MNNNormInt8;
|
||||
|
||||
|
|
@ -2143,7 +2143,7 @@ void MNNCoreInt8FunctionInit() {
|
|||
gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 4>;
|
||||
// ConvDepthwise
|
||||
gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3;
|
||||
|
||||
|
||||
}
|
||||
if (core->supportI8mm) {
|
||||
// MatMul
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
#include "CommonOptFunction.h"
|
||||
#include "core/Concurrency.h"
|
||||
#include "core/TensorUtils.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
#include "MNN/AutoTime.hpp"
|
||||
#include <math.h>
|
||||
#ifdef MNN_USE_SSE
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@
|
|||
#include "core/TensorUtils.hpp"
|
||||
#include "math/Vec.hpp"
|
||||
#include "core/BufferAllocator.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "common/CommonCompute.hpp"
|
||||
#include "core/MemoryFormater.h"
|
||||
#include "core/CommonCompute.hpp"
|
||||
|
||||
using Vec4 = MNN::Math::Vec<float, 4>;
|
||||
namespace MNN {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
#include <map>
|
||||
#include "core/Macro.h"
|
||||
#include "math/Vec.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
|
||||
using Vec4 = MNN::Math::Vec<float, 4>;
|
||||
#define DEFAULT_UNIT 8
|
||||
|
|
|
|||
|
|
@ -75,7 +75,6 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
|
|||
const float* scale_ptr = param[4];
|
||||
auto one_int16 = _mm256_set1_epi16(1);
|
||||
auto offset_int8 = _mm256_set1_epi8(128);
|
||||
auto _int4_signed_8 = _mm256_set1_ps(8);
|
||||
for (int ci = 0; ci < dst_depth_quad; ++ci) {
|
||||
float* dstZ = C + ci * pack * realSize;
|
||||
const int8_t* weight = B + ci * weight_step;
|
||||
|
|
@ -83,14 +82,13 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
|
|||
auto zero = zero_ptr + ci * pack;
|
||||
auto bias = bias_ptr + ci * pack;
|
||||
__m256 alphaValue = _mm256_loadu_ps(alpha);
|
||||
auto extra_sum = _mm256_mul_ps(_int4_signed_8, alphaValue);
|
||||
for (int j = 0; j < realSize; ++j) {
|
||||
const float* sums = sums_ptr + j;
|
||||
const float* scale = scale_ptr + j;
|
||||
float* dstX = dstZ + j * pack;
|
||||
__m256 scaleValue = _mm256_set1_ps(scale[0]);
|
||||
auto sum_val = _mm256_set1_ps(sums[0]);
|
||||
__m256 biasValue = _mm256_add_ps(_mm256_loadu_ps(bias), _mm256_mul_ps(_mm256_sub_ps(_mm256_loadu_ps(zero), extra_sum), sum_val));
|
||||
__m256 biasValue = _mm256_add_ps(_mm256_loadu_ps(bias), _mm256_mul_ps(_mm256_loadu_ps(zero), sum_val));
|
||||
const int8_t* srcBatch = A + j * pack;
|
||||
auto oc0123_int16 = _mm256_set1_epi16(0);
|
||||
auto oc4567_int16 = _mm256_set1_epi16(0);
|
||||
|
|
@ -103,10 +101,8 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
|
|||
const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride;
|
||||
auto s0 = _mm256_castpd_si256(_mm256_broadcast_sd((double*)srcZ));
|
||||
auto wi4 = _mm256_castps_si256(_mm256_loadu_ps((const float*)weightZ));
|
||||
auto w_high = _mm256_and_si256(mask, _mm256_srli_epi16(wi4, 4));
|
||||
auto w_low = _mm256_and_si256(mask, wi4);
|
||||
auto w0_ = _mm256_unpacklo_epi8(w_high, w_low);
|
||||
auto w1_ = _mm256_unpackhi_epi8(w_high, w_low);
|
||||
auto w0_ = _mm256_and_si256(mask, _mm256_srli_epi16(wi4, 4));
|
||||
auto w1_ = _mm256_and_si256(mask, wi4);
|
||||
auto w0 = _mm256_permute2x128_si256(w0_, w1_, 0x20);
|
||||
auto w1 = _mm256_permute2x128_si256(w0_, w1_, 0x31);
|
||||
oc0123_int16 = _mm256_maddubs_epi16(w0, s0); // int16_t sum
|
||||
|
|
@ -194,9 +190,11 @@ void _AVX_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t s
|
|||
auto oc_04261537 = _mm256_add_epi32(oc_04261537_lo, oc_04261537_hi);
|
||||
auto oc_0426 = _mm256_extractf128_si256(oc_04261537, 0);
|
||||
auto oc_1537 = _mm256_extractf128_si256(oc_04261537, 1);
|
||||
auto oc_0123 = _mm_unpacklo_epi32(oc_0426, oc_1537);
|
||||
auto oc_4567 = _mm_unpackhi_epi32(oc_0426, oc_1537);
|
||||
auto sum8 = _mm256_set_m128i(oc_0123, oc_4567);
|
||||
auto oc_0145 = _mm_unpacklo_epi32(oc_0426, oc_1537);
|
||||
auto oc_2367 = _mm_unpackhi_epi32(oc_0426, oc_1537);
|
||||
auto oc_0123 = _mm_unpacklo_epi64(oc_0145, oc_2367);
|
||||
auto oc_4567 = _mm_unpackhi_epi64(oc_0145, oc_2367);
|
||||
auto sum8 = _mm256_set_m128i(oc_4567, oc_0123);
|
||||
__m256 f0 = _mm256_cvtepi32_ps(sum8);
|
||||
__m256 fs = _mm256_mul_ps(_mm256_mul_ps(f0, scaleValue), alphaValue);
|
||||
fs = _mm256_add_ps(biasValue, fs);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include <map>
|
||||
#include "Vec8.hpp"
|
||||
#include "FunctionSummary.hpp"
|
||||
#include "common/MemoryFormater.h"
|
||||
#include "core/MemoryFormater.h"
|
||||
#define PACK_UNIT 8
|
||||
namespace MNN {
|
||||
|
||||
|
|
|
|||
|
|
@ -656,8 +656,8 @@ static inline __m128i _load_int4_to_int8(const uint8_t* src) {
|
|||
int32_t data[4];
|
||||
int8_t temp[16];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
temp[2 * i] = (src[i] >> 4) - 8;
|
||||
temp[2 * i +1] = (src[i] & c) - 8;
|
||||
temp[2 * i] = (src[i] >> 4);
|
||||
temp[2 * i +1] = (src[i] & c);
|
||||
}
|
||||
auto int8_tx16 = _mm_loadu_si128((const __m128i*)temp);
|
||||
return int8_tx16;
|
||||
|
|
|
|||
|
|
@ -1,32 +1,27 @@
|
|||
#include "AllShader.hpp"
|
||||
const char* shader_MetalReLU6_metal =
|
||||
"kernel void relu6_x1(const device M *in [[buffer(0)]],\n"
|
||||
" device M *out [[buffer(1)]],\n"
|
||||
" constant float4 &minMax [[buffer(2)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=clamp(in[int(gid)],(M)(minMax.x),(M)(minMax.y));\n"
|
||||
"}\n"
|
||||
"kernel void relu6_x4(const device M4 *in [[buffer(0)]],\n"
|
||||
"struct Param {\n"
|
||||
" float minV;\n"
|
||||
" float maxV;\n"
|
||||
" int size;\n"
|
||||
" int remain;\n"
|
||||
"};\n"
|
||||
"kernel void relu6(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant float4 &minMax [[buffer(2)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=clamp(in[int(gid)],(M4)minMax.x,(M4)minMax.y);\n"
|
||||
" constant Param &p [[buffer(2)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if (gid.x<p.size) {\n"
|
||||
" out[int(gid.x)]=clamp(in[int(gid.x)],(M4)p.minV,(M4)p.maxV);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
;
|
||||
const char* shader_MetalReLU_metal =
|
||||
"kernel void relu_x1(const device M *in [[buffer(0)]],\n"
|
||||
" device M *out [[buffer(1)]],\n"
|
||||
" constant float &slope [[buffer(2)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" auto V=in[int(gid)];\n"
|
||||
" out[int(gid)]=fmax(V,(M)0)+fmin(V,(M)0)*M(slope);\n"
|
||||
"}\n"
|
||||
"kernel void relu_x4(const device M4 *in [[buffer(0)]],\n"
|
||||
"kernel void relu(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant float &slope [[buffer(2)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" auto V=in[int(gid)];\n"
|
||||
" out[int(gid)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(slope);\n"
|
||||
" constant Param &p [[buffer(2)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if (gid.x<p.size) {\n"
|
||||
" auto V=in[int(gid.x)];\n"
|
||||
" out[int(gid.x)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(p.minV);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
;
|
||||
const char* shader_MetalConvolutionDepthwise_metal =
|
||||
|
|
@ -59,7 +54,7 @@ const char* shader_MetalConvolutionDepthwise_metal =
|
|||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;\n"
|
||||
" \n"
|
||||
" int oz=gid.z % cst.slice;\n"
|
||||
" int oz=gid.z/cst.batch;\n"
|
||||
" int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;\n"
|
||||
" int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;\n"
|
||||
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
|
||||
|
|
@ -198,16 +193,16 @@ const char* shader_MetalConvolution_metal =
|
|||
" offset_x += sx*cst.dilation_x;\n"
|
||||
" offset_y += sy*cst.dilation_y;\n"
|
||||
" \n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+(int)idx_c*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+(int)idx_c*cst.batch*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
|
||||
" int dilation_h=cst.input_width*cst.dilation_y;\n"
|
||||
" FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" for (auto y=0; y<kh; y++) {\n"
|
||||
" for (auto x=0; x<kw; x++) {\n"
|
||||
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
|
||||
" auto in4=z_in[z*cst.input_size+y*dilation_h+x*cst.dilation_x];\n"
|
||||
" auto in4=z_in[z*cst.input_size*cst.batch+y*dilation_h+x*cst.dilation_x];\n"
|
||||
" result += FLOAT4(in4*wt4);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
|
|
@ -231,14 +226,14 @@ const char* shader_MetalConvolution_metal =
|
|||
" bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n"
|
||||
" int offset_x=(int)gid.x*2-cst.pad_x;\n"
|
||||
" int offset_y=(int)gid.y-cst.pad_y;\n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" \n"
|
||||
" int ws=cst.input_slice*cst.kernel_size;\n"
|
||||
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
|
||||
" FLOAT4 result4=0,result5=0,result6=0,result7=0;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += cst.input_size) {\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += (cst.input_size*cst.batch)) {\n"
|
||||
" auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n"
|
||||
" auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n"
|
||||
" auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n"
|
||||
|
|
@ -312,18 +307,18 @@ const char* shader_MetalConvolution_metal =
|
|||
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
|
||||
" bool valid=(idx_w+1<cst.output_width);\n"
|
||||
" \n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+idx_h*cst.input_width+idx_w;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
|
||||
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result1=result0;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" for (auto y=0; y<cst.kernel_y; y++) {\n"
|
||||
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x];\n"
|
||||
" auto in4_0=z_in[z*cst.input_size+y*cst.input_width];\n"
|
||||
" auto in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width];\n"
|
||||
" result0 += FLOAT4(in4_0*wt4);\n"
|
||||
" for (auto x=1; x<cst.kernel_x; x++) {\n"
|
||||
" in4_0=z_in[z*cst.input_size+y*cst.input_width+x];\n"
|
||||
" in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+x];\n"
|
||||
" result1 += FLOAT4(in4_0*wt4);\n"
|
||||
" wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
|
||||
" result0 += FLOAT4(in4_0*wt4);\n"
|
||||
|
|
@ -352,9 +347,9 @@ const char* shader_MetalConvolution_metal =
|
|||
" int3 uz=idx_w+int3(1,2,3);\n"
|
||||
" bool3 valids=uz.xyz<cst.output_width;\n"
|
||||
" \n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+idx_h*cst.input_width+idx_w;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
|
||||
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result1=result0;\n"
|
||||
" FLOAT4 result2=result0;\n"
|
||||
|
|
@ -365,7 +360,7 @@ const char* shader_MetalConvolution_metal =
|
|||
" auto wt4_0=wt_base[0];\n"
|
||||
" auto wt4_1=wt_base[1];\n"
|
||||
" auto wt4_2=wt_base[2];\n"
|
||||
" auto z_in_base=z_in+z*cst.input_size+y*cst.input_width;\n"
|
||||
" auto z_in_base=z_in+z*cst.batch*cst.input_size+y*cst.input_width;\n"
|
||||
" auto in4_0=z_in_base[0];\n"
|
||||
" result0 += FLOAT4(in4_0*wt4_0);\n"
|
||||
" \n"
|
||||
|
|
@ -420,14 +415,14 @@ const char* shader_MetalConvolution_metal =
|
|||
" offset_x += sx*cst.dilation_x;\n"
|
||||
" offset_y += sy*cst.dilation_y;\n"
|
||||
" \n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" \n"
|
||||
" int ws=cst.input_slice*cst.kernel_size;\n"
|
||||
" int dilation_h=cst.input_width*cst.dilation_y;\n"
|
||||
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size) {\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
|
||||
" for (auto y=0; y<kh; y++) {\n"
|
||||
" for (auto x=0; x<kw; x++) {\n"
|
||||
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
|
||||
|
|
@ -471,14 +466,14 @@ const char* shader_MetalConvolution_metal =
|
|||
" offset_x += sx*cst.dilation_x;\n"
|
||||
" offset_y += sy*cst.dilation_y;\n"
|
||||
" \n"
|
||||
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
|
||||
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
|
||||
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" \n"
|
||||
" int ws=cst.input_slice*cst.kernel_size;\n"
|
||||
" int dilation_h=cst.input_width*cst.dilation_y;\n"
|
||||
" FLOAT4 result0=0,result1=0;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size) {\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
|
||||
" for (auto y=0; y<kh; y++) {\n"
|
||||
" for (auto x=0; x<kw; x++) {\n"
|
||||
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
|
||||
|
|
@ -489,7 +484,7 @@ const char* shader_MetalConvolution_metal =
|
|||
" }\n"
|
||||
" }\n"
|
||||
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
|
||||
" if (valids) { z_out += cst.output_size; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
|
||||
" if (valids) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
|
||||
"}\n"
|
||||
;
|
||||
const char* shader_MetalReduction_metal =
|
||||
|
|
@ -567,34 +562,9 @@ const char* shader_MetalBackend_metal =
|
|||
"struct tensor_shape {\n"
|
||||
" int size;\n"
|
||||
" int channel;\n"
|
||||
" int slice;\n"
|
||||
" int batch;\n"
|
||||
" int batch_slices;\n"
|
||||
"};\n"
|
||||
"kernel void version_func_002(const device uchar *in [[buffer(0)]],\n"
|
||||
" device uchar *out [[buffer(1)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" // do nothing,just for verifying match between mnn and metallib\n"
|
||||
"}\n"
|
||||
"kernel void copy_byte(const device uchar *in [[buffer(0)]],\n"
|
||||
" device uchar *out [[buffer(1)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=in[int(gid)];\n"
|
||||
"}\n"
|
||||
"kernel void copy_float(const device M *in [[buffer(0)]],\n"
|
||||
" device M *out [[buffer(1)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=in[int(gid)];\n"
|
||||
"}\n"
|
||||
"kernel void upcast_float(const device M *in [[buffer(0)]],\n"
|
||||
" device float *out [[buffer(1)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=in[int(gid)];\n"
|
||||
"}\n"
|
||||
"kernel void downcast_float(const device float *in [[buffer(0)]],\n"
|
||||
" device M *out [[buffer(1)]],\n"
|
||||
" uint gid [[thread_position_in_grid]]) {\n"
|
||||
" out[int(gid)]=in[int(gid)];\n"
|
||||
"}\n"
|
||||
"struct Limit {\n"
|
||||
" uint4 size;\n"
|
||||
"};\n"
|
||||
|
|
@ -616,8 +586,8 @@ const char* shader_MetalBackend_metal =
|
|||
"}\n"
|
||||
"template <typename IType,typename OType>\n"
|
||||
"static inline void template_NHWC_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int z=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int z=gid.y/s.batch;\n"
|
||||
" int c=z*4;\n"
|
||||
" \n"
|
||||
" auto off_in=in+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
|
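These NC4HW4 conversion kernels now decode the packed index as b = gid.y % s.batch and z = gid.y / s.batch, so within one channel-slice all batches sit contiguously. A host-side sketch of the element offset this implies for a value at (n, c, h, w); the helper name and host-side framing are assumptions for illustration, not part of the shaders:

#include <cstddef>

// Assumed NC4HW4 layout after this change: [channel/4][batch][height*width][4],
// matching b = index % batch, slice = index / batch in the kernels above.
static size_t nc4hw4Offset(int n, int c, int h, int w,
                           int batch, int height, int width) {
    int slice  = c / 4;
    int within = c % 4;
    size_t area = (size_t)height * width;
    return (((size_t)slice * batch + n) * area + (size_t)h * width + w) * 4 + within;
}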
||||
|
|
@ -653,8 +623,8 @@ const char* shader_MetalBackend_metal =
|
|||
"}\n"
|
||||
"template <typename IType,typename OType>\n"
|
||||
"static inline void template_NC4HW4_to_NHWC(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int z=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int z=gid.y/s.batch;\n"
|
||||
" int c=z*4;\n"
|
||||
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
|
||||
" auto off_out=out+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
|
||||
|
|
@ -691,8 +661,8 @@ const char* shader_MetalBackend_metal =
|
|||
"}\n"
|
||||
"template <typename IType,typename OType>\n"
|
||||
"static inline void template_NCHW_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int z=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int z=gid.y/s.batch;\n"
|
||||
" int c=z*4;\n"
|
||||
" \n"
|
||||
" auto off_in=in+(b*s.channel+c)*s.size+int(gid.x);\n"
|
||||
|
|
@ -728,8 +698,8 @@ const char* shader_MetalBackend_metal =
|
|||
"}\n"
|
||||
"template <typename IType,typename OType>\n"
|
||||
"static inline void template_NC4HW4_to_NCHW(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int z=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int z=gid.y/s.batch;\n"
|
||||
" int c=z*4;\n"
|
||||
" \n"
|
||||
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
|
||||
|
|
@ -767,8 +737,8 @@ const char* shader_MetalBackend_metal =
|
|||
"template<typename IType,typename OType>\n"
|
||||
"static inline void template_NHWC_to_NCHW(const device IType* in,\n"
|
||||
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int c4=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int c4=gid.y/s.batch;\n"
|
||||
" \n"
|
||||
" auto in_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
|
||||
" auto out_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
|
||||
|
|
@ -793,8 +763,8 @@ const char* shader_MetalBackend_metal =
|
|||
"template<typename IType,typename OType>\n"
|
||||
"static inline void template_NCHW_to_NHWC(const device IType* in,\n"
|
||||
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
|
||||
" int b=gid.y/s.slice;\n"
|
||||
" int c4=gid.y % s.slice;\n"
|
||||
" int b=gid.y % s.batch;\n"
|
||||
" int c4=gid.y/s.batch;\n"
|
||||
" \n"
|
||||
" auto in_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
|
||||
" auto out_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
|
||||
|
|
@ -1452,7 +1422,7 @@ const char* shader_MetalScale_metal =
|
|||
" const device float4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint2 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n"
|
||||
" int z=gid.y % s.steps;\n"
|
||||
" int z=gid.y/s.batch;\n"
|
||||
" out[int(gid.y)*s.size+int(gid.x)] =\n"
|
||||
" in [int(gid.y)*s.size+int(gid.x)]*M4(scales[z])+M4(biasTerms[z]);\n"
|
||||
"}\n"
|
||||
|
|
@ -1494,8 +1464,8 @@ const char* shader_MetalDeconvolution_metal =
|
|||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" \n"
|
||||
" int b=gid.z/cst.output_slice;\n"
|
||||
" int o=gid.z % cst.output_slice;\n"
|
||||
" int b=gid.z % cst.batch;\n"
|
||||
" int o=gid.z/cst.batch;\n"
|
||||
" FLOAT4 result=FLOAT4(biasTerms[o]);\n"
|
||||
" int oy=(int)gid.y+cst.pad_y;\n"
|
||||
" int ox=(int)gid.x+cst.pad_x;\n"
|
||||
|
|
@ -1513,12 +1483,12 @@ const char* shader_MetalDeconvolution_metal =
|
|||
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
|
||||
" \n"
|
||||
" auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n"
|
||||
" auto b_in=in+b*cst.input_slice*cst.input_size;\n"
|
||||
" auto b_in=in+b*cst.input_size;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
|
||||
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
|
||||
" auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n"
|
||||
" auto in4=b_in[z*cst.input_size+iy*cst.input_width+ix];\n"
|
||||
" auto in4=b_in[z*cst.input_size*cst.batch+iy*cst.input_width+ix];\n"
|
||||
" result += FLOAT4(in4*wt4);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
|
|
@ -1534,7 +1504,7 @@ const char* shader_MetalDeconvolution_metal =
|
|||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" \n"
|
||||
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z % cst.input_slice)]);\n"
|
||||
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z/cst.batch)]);\n"
|
||||
" \n"
|
||||
" int oy=(int)gid.y+cst.pad_y;\n"
|
||||
" int ox=(int)gid.x+cst.pad_x;\n"
|
||||
|
|
@ -1631,10 +1601,11 @@ const char* shader_MetalROIPooling_metal =
|
|||
" int input_width;\n"
|
||||
" int input_height;\n"
|
||||
" int input_size;\n"
|
||||
" int input_batch;\n"
|
||||
" int output_width;\n"
|
||||
" int output_height;\n"
|
||||
" int output_size;\n"
|
||||
" int slices;\n"
|
||||
" int batch;\n"
|
||||
" float spatial_scale;\n"
|
||||
"};\n"
|
||||
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n"
|
||||
|
|
@ -1644,10 +1615,10 @@ const char* shader_MetalROIPooling_metal =
|
|||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n"
|
||||
" \n"
|
||||
" int ob=gid.z/s.slices;\n"
|
||||
" int iz=gid.z % s.slices;\n"
|
||||
" int ob=gid.z % s.batch;\n"
|
||||
" int iz=gid.z/s.batch;\n"
|
||||
" \n"
|
||||
" auto b_roi=roi+ob*8; // roundup(5,4)=8\n"
|
||||
" auto b_roi=roi+ob*5;\n"
|
||||
" int ib=int(b_roi[0]);\n"
|
||||
" int x1=round(float(b_roi[1])*s.spatial_scale);\n"
|
||||
" int y1=round(float(b_roi[2])*s.spatial_scale);\n"
|
||||
|
|
@ -1656,8 +1627,8 @@ const char* shader_MetalROIPooling_metal =
|
|||
" \n"
|
||||
" int roi_w=max(x2-x1+1,1);\n"
|
||||
" int roi_h=max(y2-y1+1,1);\n"
|
||||
" auto bin_size_w=(M)roi_w/s.output_width;\n"
|
||||
" auto bin_size_h=(M)roi_h/s.output_height;\n"
|
||||
" float bin_size_w=(float)roi_w/(float)s.output_width;\n"
|
||||
" float bin_size_h=(float)roi_h/(float)s.output_height;\n"
|
||||
" \n"
|
||||
" int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n"
|
||||
" int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n"
|
||||
|
|
@ -1665,7 +1636,7 @@ const char* shader_MetalROIPooling_metal =
|
|||
" int h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n"
|
||||
" \n"
|
||||
" int is_empty=(h_end <= h_start) || (w_end <= w_start);\n"
|
||||
" auto z_in=in+(ib*s.slices+iz)*s.input_size;\n"
|
||||
" auto z_in=in+(ib+iz*s.input_batch)*s.input_size;\n"
|
||||
" auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n"
|
||||
" for (int y=h_start; y<h_end; y++) {\n"
|
||||
" auto y_in=z_in+y*s.input_width;\n"
|
||||
|
|
@ -1690,30 +1661,6 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" int batch;\n"
|
||||
" conv_activation_type activation;\n"
|
||||
"};\n"
|
||||
"kernel void conv1x1_w1h1(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" int idx_w=gid.x;\n"
|
||||
" int idx_h=gid.y;\n"
|
||||
" int idx_c=gid.z % cst.output_slice;\n"
|
||||
" int idx_b=gid.z/cst.output_slice;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result0=biasValue;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=xy_in0[0];\n"
|
||||
" auto w=xy_wt[z];\n"
|
||||
" result0 += FLOAT4(in40*w);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" }\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
|
|
@ -1725,8 +1672,8 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" int rx=gid.x*CONV_UNROLL;\n"
|
||||
" int uz=gid.y;\n"
|
||||
" auto xy_wt=wt+uz*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
|
||||
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
|
||||
|
|
@ -1741,7 +1688,7 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result1 += FLOAT4(in41*w);\n"
|
||||
" result2 += FLOAT4(in42*w);\n"
|
||||
" result3 += FLOAT4(in43*w);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
|
|
@ -1755,19 +1702,18 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" const device MNN::char4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" const device float4 *dequantScale [[buffer(5)]],\n"
|
||||
" const device float4 *dequantBias [[buffer(6)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
|
||||
" int rx=gid.x*CONV_UNROLL;\n"
|
||||
" int uz=gid.y;\n"
|
||||
" auto xy_wt=wt+uz*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
|
||||
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
|
||||
" auto scale=FLOAT4(dequantScale[uz]);\n"
|
||||
" auto dequant_bias=FLOAT4(dequantBias[uz]);\n"
|
||||
" auto dequant_bias=FLOAT4(dequantScale[uz+cst.output_slice]);\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=(FLOAT4)*xy_in0;\n"
|
||||
" auto in41=(FLOAT4)*(xy_in0+1);\n"
|
||||
|
|
@ -1779,20 +1725,14 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n"
|
||||
" FLOAT4x4 w_dequant;\n"
|
||||
" for (int i=0; i<4; ++i) {\n"
|
||||
" FLOAT4 w4=w_fp32[i];\n"
|
||||
" FLOAT4 res;\n"
|
||||
" for (int j=0; j<4; ++j) {\n"
|
||||
" float wf=w4[j]*scale[i]+dequant_bias[i];\n"
|
||||
" res[j]=wf;\n"
|
||||
" }\n"
|
||||
" w_dequant[i]=res;\n"
|
||||
" w_dequant[i]=w_fp32[i]*scale[i]+dequant_bias[i];\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" result0 += FLOAT4(in40*w_dequant);\n"
|
||||
" result1 += FLOAT4(in41*w_dequant);\n"
|
||||
" result2 += FLOAT4(in42*w_dequant);\n"
|
||||
" result3 += FLOAT4(in43*w_dequant);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" /* true */ \n"
|
||||
|
|
@ -1807,19 +1747,18 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" const device float4 *dequantScale [[buffer(5)]],\n"
|
||||
" const device float4 *dequantBias [[buffer(6)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
|
||||
" int rx=gid.x*CONV_UNROLL;\n"
|
||||
" int uz=gid.y;\n"
|
||||
" auto xy_wt=wt+uz*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
|
||||
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
|
||||
" auto scale=FLOAT4(dequantScale[uz]);\n"
|
||||
" auto dequant_bias=FLOAT4(dequantBias[uz]);\n"
|
||||
" auto dequant_bias=FLOAT4(dequantScale[uz+cst.output_slice]);\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=(FLOAT4)*xy_in0;\n"
|
||||
" auto in41=(FLOAT4)*(xy_in0+1);\n"
|
||||
|
|
@ -1833,11 +1772,7 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" for (int i=0; i<4; ++i) {\n"
|
||||
" // M4 w4=M4(w_fp32[i]);\n"
|
||||
" FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
|
||||
" FLOAT4 res;\n"
|
||||
" for (int j=0; j<4; ++j) {\n"
|
||||
" float wf=w4[j]*scale[i]+dequant_bias[i];\n"
|
||||
" res[j]=wf;\n"
|
||||
" }\n"
|
||||
" FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
|
||||
" w_dequant[i]=res;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
|
|
@ -1845,7 +1780,7 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result1 += FLOAT4(in41*w_dequant);\n"
|
||||
" result2 += FLOAT4(in42*w_dequant);\n"
|
||||
" result3 += FLOAT4(in43*w_dequant);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" \n"
|
||||
" /* true */ \n"
|
||||
|
|
@ -1874,8 +1809,8 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" int rx=gid.x*CONV_UNROLL_L;\n"
|
||||
" int uz=gid.y;\n"
|
||||
" auto xy_wt=wt+uz*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
|
||||
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
|
||||
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.batch*cst.output_size+rx;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
|
||||
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
|
||||
|
|
@ -1898,7 +1833,7 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result5 += FLOAT4(in45*w);\n"
|
||||
" result6 += FLOAT4(in46*w);\n"
|
||||
" result7 += FLOAT4(in47*w);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
|
||||
|
|
@ -1909,71 +1844,20 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n"
|
||||
" if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_w4h2(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" int idx_w=gid.x << 2;\n"
|
||||
" int idx_h=gid.y << 1;\n"
|
||||
" int idx_c=gid.z % cst.output_slice;\n"
|
||||
" int idx_b=gid.z/cst.output_slice;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
|
||||
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=xy_in0[0];\n"
|
||||
" auto in41=xy_in0[1];\n"
|
||||
" auto in42=xy_in0[2];\n"
|
||||
" auto in43=xy_in0[3];\n"
|
||||
" auto in44=xy_in0[cst.output_width+0];\n"
|
||||
" auto in45=xy_in0[cst.output_width+1];\n"
|
||||
" auto in46=xy_in0[cst.output_width+2];\n"
|
||||
" auto in47=xy_in0[cst.output_width+3];\n"
|
||||
" auto w=xy_wt[z];\n"
|
||||
" result0 += FLOAT4(in40*w);\n"
|
||||
" result1 += FLOAT4(in41*w);\n"
|
||||
" result2 += FLOAT4(in42*w);\n"
|
||||
" result3 += FLOAT4(in43*w);\n"
|
||||
" result4 += FLOAT4(in44*w);\n"
|
||||
" result5 += FLOAT4(in45*w);\n"
|
||||
" result6 += FLOAT4(in46*w);\n"
|
||||
" result7 += FLOAT4(in47*w);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" }\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,4);\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
|
||||
" \n"
|
||||
" int heightSize=min(cst.output_height-idx_h,2);\n"
|
||||
" if(heightSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result6),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result7),cst.activation); }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*4 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" int idx_w=gid.x << 2;\n"
|
||||
" int idx_h=gid.y << 2;\n"
|
||||
" int idx_c=gid.z % cst.output_slice;\n"
|
||||
" int idx_b=gid.z/cst.output_slice;\n"
|
||||
" if ((int)gid.x*16 >= cst.output_width || (int)gid.y >= cst.batch*cst.output_slice) return;\n"
|
||||
" int idx_w=gid.x << 4;\n"
|
||||
" int idx_h=0;\n"
|
||||
" int idx_c=gid.y/cst.batch;\n"
|
||||
" int idx_b=gid.y % cst.batch;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n"
|
||||
" FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n"
|
||||
|
|
@ -1984,19 +1868,19 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" auto in01=xy_in0[1];\n"
|
||||
" auto in02=xy_in0[2];\n"
|
||||
" auto in03=xy_in0[3];\n"
|
||||
" auto in10=xy_in0[cst.output_width+0];\n"
|
||||
" auto in11=xy_in0[cst.output_width+1];\n"
|
||||
" auto in12=xy_in0[cst.output_width+2];\n"
|
||||
" auto in13=xy_in0[cst.output_width+3];\n"
|
||||
" auto in10=xy_in0[4];\n"
|
||||
" auto in11=xy_in0[5];\n"
|
||||
" auto in12=xy_in0[6];\n"
|
||||
" auto in13=xy_in0[7];\n"
|
||||
" \n"
|
||||
" auto in20=xy_in0[cst.output_width+cst.output_width+0];\n"
|
||||
" auto in21=xy_in0[cst.output_width+cst.output_width+1];\n"
|
||||
" auto in22=xy_in0[cst.output_width+cst.output_width+2];\n"
|
||||
" auto in23=xy_in0[cst.output_width+cst.output_width+3];\n"
|
||||
" auto in30=xy_in0[cst.output_width+cst.output_width+cst.output_width+0];\n"
|
||||
" auto in31=xy_in0[cst.output_width+cst.output_width+cst.output_width+1];\n"
|
||||
" auto in32=xy_in0[cst.output_width+cst.output_width+cst.output_width+2];\n"
|
||||
" auto in33=xy_in0[cst.output_width+cst.output_width+cst.output_width+3];\n"
|
||||
" auto in20=xy_in0[8];\n"
|
||||
" auto in21=xy_in0[9];\n"
|
||||
" auto in22=xy_in0[10];\n"
|
||||
" auto in23=xy_in0[11];\n"
|
||||
" auto in30=xy_in0[12];\n"
|
||||
" auto in31=xy_in0[13];\n"
|
||||
" auto in32=xy_in0[14];\n"
|
||||
" auto in33=xy_in0[15];\n"
|
||||
" auto w=xy_wt[z];\n"
|
||||
" result00 += FLOAT4(in00*w);\n"
|
||||
" result01 += FLOAT4(in01*w);\n"
|
||||
|
|
@ -2016,33 +1900,25 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result32 += FLOAT4(in32*w);\n"
|
||||
" result33 += FLOAT4(in33*w);\n"
|
||||
" \n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,4);\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,16);\n"
|
||||
" /* true */ *xy_out=activate(M4(result00),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n"
|
||||
" \n"
|
||||
" int heightSize=min(cst.output_height-idx_h,4);\n"
|
||||
" if(heightSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result10),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result11),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result12),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result13),cst.activation); }\n"
|
||||
" }\n"
|
||||
" if(heightSize>2) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+cst.output_width+0]=activate(M4(result20),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+1]=activate(M4(result21),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+2]=activate(M4(result22),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+3]=activate(M4(result23),cst.activation); }\n"
|
||||
" }\n"
|
||||
" if(heightSize>3) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+cst.output_width+cst.output_width+0]=activate(M4(result30),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+cst.output_width+1]=activate(M4(result31),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+cst.output_width+2]=activate(M4(result32),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+cst.output_width+3]=activate(M4(result33),cst.activation); }\n"
|
||||
" }\n"
|
||||
" if (widthSize>4) {xy_out[4]=activate(M4(result10),cst.activation); }\n"
|
||||
" if (widthSize>5) {xy_out[5]=activate(M4(result11),cst.activation); }\n"
|
||||
" if (widthSize>6) {xy_out[6]=activate(M4(result12),cst.activation); }\n"
|
||||
" if (widthSize>7) {xy_out[7]=activate(M4(result13),cst.activation); }\n"
|
||||
" if (widthSize>8) {xy_out[8]=activate(M4(result20),cst.activation); }\n"
|
||||
" if (widthSize>9) {xy_out[9]=activate(M4(result21),cst.activation); }\n"
|
||||
" if (widthSize>10) {xy_out[10]=activate(M4(result22),cst.activation); }\n"
|
||||
" if (widthSize>11) {xy_out[11]=activate(M4(result23),cst.activation); }\n"
|
||||
" if (widthSize>12) {xy_out[12]=activate(M4(result30),cst.activation); }\n"
|
||||
" if (widthSize>13) {xy_out[13]=activate(M4(result31),cst.activation); }\n"
|
||||
" if (widthSize>14) {xy_out[14]=activate(M4(result32),cst.activation); }\n"
|
||||
" if (widthSize>15) {xy_out[15]=activate(M4(result33),cst.activation); }\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
|
|
@ -2050,17 +1926,17 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z*2 >= cst.batch*cst.output_slice) return;\n"
|
||||
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n"
|
||||
" int channel_pack=(cst.output_channel+7) >> 3;\n"
|
||||
" int idx_w=gid.x << 1;\n"
|
||||
" int idx_h=gid.y;\n"
|
||||
" int idx_c=(gid.z % channel_pack) << 1;\n"
|
||||
" int idx_b=gid.z/channel_pack;\n"
|
||||
" int idx_h=0;\n"
|
||||
" int idx_c=(gid.y % channel_pack) << 1;\n"
|
||||
" int idx_b=gid.y/channel_pack;\n"
|
||||
" \n"
|
||||
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
|
||||
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
|
||||
|
|
@ -2074,7 +1950,7 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result1 += FLOAT4(in41*w0);\n"
|
||||
" result4 += FLOAT4(in40*w1);\n"
|
||||
" result5 += FLOAT4(in41*w1);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,2);\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
|
|
@ -2082,65 +1958,26 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" \n"
|
||||
" int channelSize=min(cst.output_slice-idx_c,2);\n"
|
||||
" if(channelSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_size+0]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result5),cst.activation); }\n"
|
||||
" /* true */ {xy_out[cst.output_size*cst.batch +0]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result5),cst.activation); }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_w2h2(const device M4 *in [[buffer(0)]],\n"
|
||||
"kernel void conv1x1_w4c2(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
|
||||
" int idx_w=gid.x << 1;\n"
|
||||
" int idx_h=gid.y << 1;\n"
|
||||
" int idx_c=gid.z % cst.output_slice;\n"
|
||||
" int idx_b=gid.z/cst.output_slice;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" FLOAT4 result0=biasValue,result1=biasValue;\n"
|
||||
" FLOAT4 result4=biasValue,result5=biasValue;\n"
|
||||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=xy_in0[0];\n"
|
||||
" auto in41=xy_in0[1];\n"
|
||||
" auto in44=xy_in0[cst.output_width+0];\n"
|
||||
" auto in45=xy_in0[cst.output_width+1];\n"
|
||||
" auto w=xy_wt[z];\n"
|
||||
" result0 += FLOAT4(in40*w);\n"
|
||||
" result1 += FLOAT4(in41*w);\n"
|
||||
" result4 += FLOAT4(in44*w);\n"
|
||||
" result5 += FLOAT4(in45*w);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" }\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,2);\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
|
||||
" \n"
|
||||
" int heightSize=min(cst.output_height-idx_h,2);\n"
|
||||
" if(heightSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"kernel void conv1x1_w2h2c2(const device M4 *in [[buffer(0)]],\n"
|
||||
" device M4 *out [[buffer(1)]],\n"
|
||||
" constant conv1x1_constants& cst [[buffer(2)]],\n"
|
||||
" const device M4x4 *wt [[buffer(3)]],\n"
|
||||
" const device M4 *biasTerms [[buffer(4)]],\n"
|
||||
" uint3 gid [[thread_position_in_grid]]) {\n"
|
||||
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z*2 >= cst.batch*cst.output_slice) return;\n"
|
||||
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n"
|
||||
" int channel_pack=(cst.output_channel+7) >> 3;\n"
|
||||
" int idx_w=gid.x << 1;\n"
|
||||
" int idx_h=gid.y << 1;\n"
|
||||
" int idx_c=(gid.z % channel_pack) << 1;\n"
|
||||
" int idx_b=gid.z/channel_pack;\n"
|
||||
" int idx_w=gid.x << 2;\n"
|
||||
" int idx_h=0;\n"
|
||||
" int idx_c=(gid.y % channel_pack) << 1;\n"
|
||||
" int idx_b=gid.y/channel_pack;\n"
|
||||
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
|
||||
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
|
||||
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
|
||||
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
|
||||
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
|
||||
|
|
@ -2150,8 +1987,8 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" for (auto z=0; z<cst.input_slice; z++) {\n"
|
||||
" auto in40=xy_in0[0];\n"
|
||||
" auto in41=xy_in0[1];\n"
|
||||
" auto in44=xy_in0[cst.output_width+0];\n"
|
||||
" auto in45=xy_in0[cst.output_width+1];\n"
|
||||
" auto in44=xy_in0[2];\n"
|
||||
" auto in45=xy_in0[3];\n"
|
||||
" auto w0=xy_wt[z];\n"
|
||||
" auto w1=xy_wt[cst.input_slice+z];\n"
|
||||
" result0 += FLOAT4(in40*w0);\n"
|
||||
|
|
@ -2162,27 +1999,20 @@ const char* shader_MetalConvolution1x1_metal =
|
|||
" result3 += FLOAT4(in41*w1);\n"
|
||||
" result6 += FLOAT4(in44*w1);\n"
|
||||
" result7 += FLOAT4(in45*w1);\n"
|
||||
" xy_in0 += cst.input_size;\n"
|
||||
" xy_in0 += cst.input_size*cst.batch;\n"
|
||||
" }\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,2);\n"
|
||||
" int widthSize=min(cst.output_width-idx_w,4);\n"
|
||||
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
|
||||
" \n"
|
||||
" int heightSize=min(cst.output_height-idx_h,2);\n"
|
||||
" if(heightSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
|
||||
" }\n"
|
||||
" if (widthSize>2) {xy_out[2]=activate(M4(result4),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[3]=activate(M4(result5),cst.activation); }\n"
|
||||
" \n"
|
||||
" int channelSize=min(cst.output_slice-idx_c,2);\n"
|
||||
" if(channelSize>1) {\n"
|
||||
" /* true */ xy_out[cst.output_size]=activate(M4(result2),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result3),cst.activation); }\n"
|
||||
" \n"
|
||||
" if(heightSize>1) {\n"
|
||||
" /* true */ {xy_out[cst.output_size+cst.output_width+0]=activate(M4(result6),cst.activation); }\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_size+cst.output_width+1]=activate(M4(result7),cst.activation); }\n"
|
||||
" }\n"
|
||||
" /* true */ xy_out[cst.output_size*cst.batch]=activate(M4(result2),cst.activation);\n"
|
||||
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result3),cst.activation); }\n"
|
||||
" if (widthSize>2) {xy_out[cst.output_size*cst.batch +2]=activate(M4(result6),cst.activation); }\n"
|
||||
" if (widthSize>3) {xy_out[cst.output_size*cst.batch +3]=activate(M4(result7),cst.activation); }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
;
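Note: the recurring change across the conv1x1 kernels above is the switch from per-batch contiguous NC4HW4 addressing (offset `b*input_slice*input_size`, stepping by `input_size` per slice) to batch-interleaved addressing (offset `b*input_size`, stepping by `input_size*batch` per slice). The following is a minimal standalone sketch of the two offset formulas; the dimensions and variable names are illustrative only and are not taken from the diff.

// Minimal sketch (not part of the diff): the addressing change used by the conv1x1 kernels above.
#include <cstdio>

int main() {
    // Assumed example dimensions: 2 batches, 3 input slices, 8 elements per plane.
    int batch = 2, input_slice = 3, input_size = 8;
    int b = 1, z = 2, x = 5; // batch index, slice index, element index within the plane

    // Old layout: all slices of one batch are stored contiguously.
    int old_offset = b * input_slice * input_size + z * input_size + x;

    // New layout: for each slice, the planes of all batches are stored contiguously,
    // so stepping to the next slice advances by input_size * batch.
    int new_offset = b * input_size + z * input_size * batch + x;

    printf("old=%d new=%d\n", old_offset, new_offset);
    return 0;
}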
|
||||
|
|
@ -2341,7 +2171,7 @@ const char* shader_MetalPReLU_metal =
|
|||
" uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n"
|
||||
" if ((int)gid.x >= s.size || (int)gid.y >= s.slice) return;\n"
|
||||
" \n"
|
||||
" int z=gid.z*s.slice+gid.y;\n"
|
||||
" int z=gid.z+gid.y*s.batch;\n"
|
||||
" auto v4=in[z*s.size+int(gid.x)];\n"
|
||||
" out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n"
|
||||
"}\n"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#ifndef MNN_METAL_SHADER_AUTO_GENERATE_H
|
||||
#define MNN_METAL_SHADER_AUTO_GENERATE_H
|
||||
extern const char* shader_MetalReLU6_metal;
|
||||
extern const char* shader_MetalReLU_metal;
|
||||
extern const char* shader_MetalConvolutionDepthwise_metal;
|
||||
extern const char* shader_MetalConvolutionActivation_metal;
|
||||
extern const char* shader_MetalConvolution_metal;
|
||||
|
|
|
|||
|
|
@ -27,14 +27,6 @@ typedef enum {
|
|||
CPUTransparent
|
||||
} MetalAccess;
|
||||
|
||||
typedef struct {
|
||||
/** wrap size */
|
||||
NSUInteger threadExecutionWidth;
|
||||
/** max threads per thread group */
|
||||
NSUInteger maxThreadsPerThreadgroup;
|
||||
/** run concurrently on z axis or not */
|
||||
BOOL zAxisProtected;
|
||||
} MetalBandwidth;
|
||||
}
|
||||
|
||||
@interface MNNMetalContext : NSObject
|
||||
|
|
@ -62,14 +54,6 @@ typedef struct {
|
|||
- (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size bytes:(const void *)bytes access:(MNN::MetalAccess)access;
|
||||
|
||||
|
||||
/**
|
||||
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
|
||||
* @param name pipeline name
|
||||
* @param encoder command encoder
|
||||
* @return bandwidth info for function
|
||||
*/
|
||||
- (MNN::MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16;
|
||||
|
||||
/**
|
||||
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
|
||||
* @param name pipeline name
|
||||
|
|
@ -86,16 +70,6 @@ typedef struct {
|
|||
|
||||
- (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device;
|
||||
|
||||
/**
|
||||
* @brief dispatch encoder with default settings
|
||||
* @param encoder command encoder
|
||||
* @param threads threads size
|
||||
* @param bandwidth bandwidth
|
||||
*/
|
||||
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
|
||||
threads:(MTLSize)threads
|
||||
bandwidth:(MNN::MetalBandwidth)bandwidth;
|
||||
|
||||
/**
|
||||
* @brief dispatch encoder with designated threads per threadgroup
|
||||
* @param encoder command encoder
|
||||
|
|
@ -103,10 +77,6 @@ typedef struct {
|
|||
* @param threadsPerGroup thread size per group
|
||||
* @param bandwidth bandwidth
|
||||
*/
|
||||
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
|
||||
threads:(MTLSize)threads
|
||||
threadsPerGroup:(MTLSize)threadsPerGroup
|
||||
bandwidth:(MNN::MetalBandwidth)bandwidth;
|
||||
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name fp16:(BOOL)fp16;
|
||||
- (id<MTLComputePipelineState>)pipelineWithSourceOption:(NSString *)source name:(NSString *)name options:(MTLCompileOptions *)options;
|
||||
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) pipeline threads:(MTLSize)threads;
|
||||
|
|
|
|||
|
|
@ -181,25 +181,6 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
|
|||
return result;
|
||||
}
|
||||
|
||||
- (MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16 {
|
||||
id<MTLComputePipelineState> pipeline = [self pipelineWithName:name fp16:fp16];
|
||||
MNN_ASSERT(nil != pipeline);
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
|
||||
if (!name) {
|
||||
} else if (!encoder.label) {
|
||||
encoder.label = name;
|
||||
} else {
|
||||
NSArray *components = [encoder.label componentsSeparatedByString:@","];
|
||||
if (![components containsObject:name]) {
|
||||
components = [components arrayByAddingObject:name];
|
||||
}
|
||||
encoder.label = [components componentsJoinedByString:@","];
|
||||
}
|
||||
#endif
|
||||
return {pipeline.threadExecutionWidth, pipeline.maxTotalThreadsPerThreadgroup, NO};
|
||||
}
|
||||
|
||||
- (NSUInteger)timeUsed:(id<MTLCommandBuffer>)buffer {
|
||||
// Get ns precision time
|
||||
auto start = mach_absolute_time();
|
||||
|
|
@ -393,37 +374,6 @@ static NSUInteger smallest_log2(NSUInteger integer) {
|
|||
}
|
||||
|
||||
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) bw threads:(MTLSize)t {
|
||||
if (bw.maxTotalThreadsPerThreadgroup > 64) {
|
||||
auto res = MTLSizeMake(8, 8, 8);
|
||||
int reduceNumber = 0;
|
||||
if (t.depth < 4) {
|
||||
res.depth = 1;
|
||||
reduceNumber++;
|
||||
}
|
||||
if (t.width < 4) {
|
||||
res.width = 1;
|
||||
reduceNumber++;
|
||||
}
|
||||
if (t.height < 4) {
|
||||
res.height = 1;
|
||||
reduceNumber++;
|
||||
}
|
||||
if (reduceNumber == 0) {
|
||||
return MTLSizeMake(4, 4, 4);
|
||||
}
|
||||
if (reduceNumber == 2) {
|
||||
if (res.width > 1) {
|
||||
res.width = 64;
|
||||
}
|
||||
if (res.height > 1) {
|
||||
res.height = 64;
|
||||
}
|
||||
if (res.depth > 1) {
|
||||
res.depth = 64;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
auto pwarp = smallest_log2(bw.threadExecutionWidth);
|
||||
auto px = smallest_log2(t.width), sx = (NSUInteger)ceil(log2(t.width));
|
||||
auto py = smallest_log2(t.height), sy = (NSUInteger)ceil(log2(t.height));
|
||||
|
|
@ -463,91 +413,6 @@ static NSUInteger smallest_log2(NSUInteger integer) {
|
|||
|
||||
}
|
||||
|
||||
- (MTLSize)threadsPerGroupWithThreads:(MTLSize)t bandwidth:(MetalBandwidth)bw {
|
||||
auto pwarp = smallest_log2(bw.threadExecutionWidth);
|
||||
auto px = smallest_log2(t.width), sx = (NSUInteger)ceil(log2(t.width));
|
||||
auto py = smallest_log2(t.height), sy = (NSUInteger)ceil(log2(t.height));
|
||||
|
||||
// accurately match on x
|
||||
if (px >= pwarp) {
|
||||
return {bw.threadExecutionWidth, 1, 1};
|
||||
}
|
||||
// accurately match on xy
|
||||
else if (px + py >= pwarp && sx < pwarp / 2) {
|
||||
NSUInteger x = pow(2, px);
|
||||
return {x, bw.threadExecutionWidth / x, 1};
|
||||
}
|
||||
// similarly match on x
|
||||
else if (sx >= pwarp) {
|
||||
return {bw.threadExecutionWidth, 1, 1};
|
||||
}
|
||||
// similarly match on xy
|
||||
else if (sx + sy >= pwarp) {
|
||||
NSUInteger x = pow(2, sx);
|
||||
return {x, bw.threadExecutionWidth / x, 1};
|
||||
}
|
||||
|
||||
// on xyz (most shaders do not protect gid.z, so the z axis must be matched exactly)
|
||||
auto pz = smallest_log2(t.depth);
|
||||
auto sz = bw.zAxisProtected ? ceil(log2(t.depth)) : pz;
|
||||
if (px + py + pz >= pwarp) {
|
||||
NSUInteger x = pow(2, px), y = pow(2, py);
|
||||
return {x, y, bw.threadExecutionWidth / x / y};
|
||||
} else if (sx + sy + sz >= pwarp) {
|
||||
NSUInteger x = pow(2, sx), z = pow(2, MIN(sz, pwarp - sx));
|
||||
return {x, bw.threadExecutionWidth / x / z, z};
|
||||
} else {
|
||||
NSUInteger z = pow(2, sz);
|
||||
return {t.width, t.height, z};
|
||||
}
|
||||
}
|
||||
|
||||
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
|
||||
threads:(MTLSize)threads
|
||||
bandwidth:(MetalBandwidth)bandwidth {
|
||||
[self dispatchEncoder:encoder
|
||||
threads:threads
|
||||
threadsPerGroup:[self threadsPerGroupWithThreads:threads bandwidth:bandwidth]
|
||||
bandwidth:bandwidth];
|
||||
}
|
||||
|
||||
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
|
||||
threads:(MTLSize)threads
|
||||
threadsPerGroup:(MTLSize)threadsPerGroup
|
||||
bandwidth:(MetalBandwidth)bandwidth {
|
||||
#if MNN_METAL_DEBUG
|
||||
if (threads.width == 0 || threads.height == 0 || threads.depth == 0 || threadsPerGroup.width == 0 ||
|
||||
threadsPerGroup.height == 0 || threadsPerGroup.depth == 0) {
|
||||
printf("[METAL] dispatch error %td %td %td / %td %td %td\n", threads.width, threads.height, threads.depth,
|
||||
threadsPerGroup.width, threadsPerGroup.height, threadsPerGroup.depth);
|
||||
}
|
||||
#endif
|
||||
|
||||
// NSLog(@"dispatch {%td %td %td} with {%td %td %td}",
|
||||
// threads.width, threads.height, threads.depth,
|
||||
// threadsPerGroup.width, threadsPerGroup.height, threadsPerGroup.depth);
|
||||
threadsPerGroup.width = MIN(threadsPerGroup.width, bandwidth.maxThreadsPerThreadgroup);
|
||||
threadsPerGroup.height = MIN(threadsPerGroup.height, bandwidth.maxThreadsPerThreadgroup);
|
||||
threadsPerGroup.depth = MIN(threadsPerGroup.depth, bandwidth.maxThreadsPerThreadgroup);
|
||||
#ifdef MNN_BUILD_FOR_IOS
|
||||
if (@available(iOS 11.0, *)) {
|
||||
if ([_device supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily4_v1]) {
|
||||
[encoder dispatchThreads:threads threadsPerThreadgroup:threadsPerGroup];
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
MTLSize groups = {
|
||||
UP_DIV(threads.width, threadsPerGroup.width), UP_DIV(threads.height, threadsPerGroup.height),
|
||||
UP_DIV(threads.depth, threadsPerGroup.depth),
|
||||
};
|
||||
MNN_ASSERT(threadsPerGroup.width >= 1);
|
||||
MNN_ASSERT(threadsPerGroup.height >= 1);
|
||||
MNN_ASSERT(threadsPerGroup.depth >= 1);
|
||||
|
||||
[encoder dispatchThreadgroups:groups threadsPerThreadgroup:threadsPerGroup];
|
||||
}
|
||||
|
||||
#if MNN_METAL_DEBUG
|
||||
#pragma mark debug
|
||||
- (void)printTensor:(const Tensor *)tensor {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
//
|
||||
// MetalArgMax.mm
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2023/12/29.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#import "core/Macro.h"
|
||||
#import "MetalCast.hpp"
|
||||
#import "MetalBackend.hpp"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,425 @@
|
|||
//
|
||||
// MetalAttention.mm
|
||||
// MNN
|
||||
//
|
||||
//  Created by MNN on 2024/04/29.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include <set>
|
||||
#import "core/Macro.h"
|
||||
#import "MetalCast.hpp"
|
||||
#import "MetalBackend.hpp"
|
||||
#import "MNNMetalContext.h"
|
||||
#include "MNN_generated.h"
|
||||
|
||||
#if MNN_METAL_ENABLED
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
|
||||
static const char* gMatMulDivMask = R"metal(
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
using namespace metal;
|
||||
struct Param {
|
||||
int query_seq_len;
|
||||
int key_seq_len;
|
||||
int head_num;
|
||||
int head_dim;
|
||||
float scale;
|
||||
};
|
||||
|
||||
kernel void main0(const device T* input0 [[buffer(0)]],
|
||||
const device T* input1 [[buffer(1)]],
|
||||
device T* output [[buffer(2)]],
|
||||
device T* past_key [[buffer(3)]],
|
||||
const device int* mask [[buffer(4)]],
|
||||
constant Param& param [[buffer(5)]],
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
const int x = gid.x; // query_seq_len
|
||||
const int y = gid.y; // head_num
|
||||
const int z = gid.z; // key_seq_len
|
||||
if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) {
|
||||
return;
|
||||
}
|
||||
int query_seq_len = param.query_seq_len;
|
||||
int key_seq_len = param.key_seq_len;
|
||||
int head_num = param.head_num;
|
||||
int head_dim = param.head_dim;
|
||||
|
||||
const int offset = head_num * head_dim;
|
||||
const int offset_head = y * head_dim;
|
||||
const device T* A_offset = input0 + x * offset + offset_head;
|
||||
device T* Pastkey_offset = past_key + z * offset + offset_head;
|
||||
float Vscale = (float)param.scale;
|
||||
#ifdef FOR_PREFILL
|
||||
device const T* B_offset = input1 + z * offset + offset_head;
|
||||
const int output_offset = y * query_seq_len * key_seq_len;
|
||||
float out0 = 0.0;
|
||||
|
||||
for(int i = 0; i < head_dim; ++i){
|
||||
float A = (float)(A_offset[i]);
|
||||
float B = (float)(B_offset[i]);
|
||||
out0 += B * A;
|
||||
Pastkey_offset[i] = (T)B;
|
||||
}
|
||||
|
||||
out0 *= Vscale;
|
||||
|
||||
out0 = mask[((x + 0) * key_seq_len + (z + 0))] == 0 ? -FLT_MAX : out0;
|
||||
output[output_offset + x * key_seq_len + z] = (T)out0;
|
||||
#else
|
||||
const device T *B_offset = input1 + offset_head;
|
||||
float out = 0.0;
|
||||
if (z == key_seq_len - 1) {
|
||||
for(int i = 0; i < head_dim; ++i){
|
||||
float A = (float)(A_offset[i]);
|
||||
float B = (float)(B_offset[i]);
|
||||
out += B * A;
|
||||
Pastkey_offset[i] = (T)B;
|
||||
}
|
||||
} else {
|
||||
for(int i = 0; i < head_dim; ++i){
|
||||
float A = A_offset[i];
|
||||
float B = (float)Pastkey_offset[i];
|
||||
|
||||
out += A * B;
|
||||
}
|
||||
}
|
||||
out *= Vscale;
|
||||
output[y + z * head_num] = (T)out;
|
||||
#endif
|
||||
}
|
||||
|
||||
)metal";
|
||||
|
||||
|
||||
static const char* gMatMulQKV = R"metal(
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
using namespace metal;
|
||||
struct Param {
|
||||
int query_seq_len;
|
||||
int key_seq_len;
|
||||
int head_num;
|
||||
int head_dim;
|
||||
float scale;
|
||||
};
|
||||
kernel void main0(const device T* input0 [[buffer(0)]],
|
||||
const device T* input1 [[buffer(1)]],
|
||||
device T* output [[buffer(2)]],
|
||||
device T* past_value [[buffer(3)]],
|
||||
constant Param& param [[buffer(4)]],
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
const int x = gid.x; // query_seq_len
|
||||
const int y = gid.y; // head_num
|
||||
const int z = gid.z; // head_dim
|
||||
if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) {
|
||||
return;
|
||||
}
|
||||
int qk_seq_len = param.query_seq_len;
|
||||
int value_seq_len = param.key_seq_len;
|
||||
int head_num = param.head_num;
|
||||
int head_dim = param.head_dim;
|
||||
const int offset = head_num * head_dim;
|
||||
const int offset_head = y * head_dim + z;
|
||||
#ifdef FOR_PREFILL
|
||||
device const T *A_offset = input0 + (y * qk_seq_len + x) * value_seq_len;
|
||||
device const T *B_offset = input1 + offset_head;
|
||||
device T *Pastvalue_offset = past_value + offset_head;
|
||||
float out = 0.0;
|
||||
|
||||
for(int i = 0; i < value_seq_len; ++i){
|
||||
float A0 = (float)A_offset[i];
|
||||
float B = (float)B_offset[i*offset];
|
||||
out += A0 * B;
|
||||
Pastvalue_offset[i*offset] = B;
|
||||
}
|
||||
output[ x * offset + (y * head_dim + z)] = out;
|
||||
#else
|
||||
device const T *A_offset = input0 + y;
|
||||
device const T *B_offset = input1 + offset_head;
|
||||
device T *Pastvalue_offset = past_value + offset_head;
|
||||
float out = 0;
|
||||
|
||||
for(int i = 0; i < value_seq_len - 1; ++i){
|
||||
float A = (float)A_offset[i * head_num];
|
||||
float B = (float)Pastvalue_offset[i * offset];
|
||||
|
||||
out += A * B;
|
||||
}
|
||||
out += (float)A_offset[(value_seq_len - 1)*head_num] * (float)B_offset[0];
|
||||
Pastvalue_offset[(value_seq_len - 1)*offset] = B_offset[0];
|
||||
output[(y * head_dim + z)] = (T)out;
|
||||
#endif
|
||||
|
||||
}
|
||||
)metal";
|
||||
|
||||
namespace MNN {
|
||||
class AttentionBufExecution : public MetalExecution {
|
||||
public:
|
||||
struct SharedCache {
|
||||
std::shared_ptr<Tensor> mPastKey;
|
||||
std::shared_ptr<Tensor> mPastValue;
|
||||
int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0;
|
||||
};
|
||||
AttentionBufExecution(Backend *backend, bool kv_cache);
|
||||
|
||||
virtual ~AttentionBufExecution() = default;
|
||||
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
|
||||
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override {
|
||||
if (nullptr == dst) {
|
||||
return true;
|
||||
}
|
||||
auto exe = new AttentionBufExecution(bn, mKVCache);
|
||||
exe->mCache = mCache;
|
||||
*dst = exe;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
void _init();
|
||||
void reallocKVCache();
|
||||
bool mKVCache;
|
||||
std::shared_ptr<SharedCache> mCache;
|
||||
float mScale;
|
||||
const int mExpandChunk = 64;
|
||||
bool mIsDecode = false;
|
||||
std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
|
||||
int mNumHead = 0, mHeadDim = 0, mValueH = 0;
|
||||
id<MTLComputePipelineState> mKernel_softmax;
|
||||
|
||||
id<MTLComputePipelineState> mKernel_qk;
|
||||
id<MTLComputePipelineState> mKernel_qkv;
|
||||
id<MTLComputePipelineState> mKernelPrefill_qk;
|
||||
id<MTLComputePipelineState> mKernelPrefill_qkv;
|
||||
id<MTLBuffer> mParamQKV;
|
||||
id<MTLBuffer> mParamSoftmax;
|
||||
};
|
||||
|
||||
struct Param {
|
||||
int query_seq_len;
|
||||
int key_seq_len;
|
||||
int head_num;
|
||||
int head_dim;
|
||||
float scale;
|
||||
};
|
||||
AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cache)
|
||||
: MetalExecution(backend) , mKVCache(kv_cache){
|
||||
_init();
|
||||
}
|
||||
void AttentionBufExecution::_init() {
|
||||
mCache.reset(new SharedCache);
|
||||
auto mtbn = static_cast<MetalBackend *>(backend());
|
||||
auto context = (__bridge MNNMetalContext *)mtbn->context();
|
||||
mKernel_softmax = [context pipelineWithName:@"softmax_plane" fp16:mtbn->useFp16InsteadFp32()];
|
||||
mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly];
|
||||
mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
|
||||
auto rt = mtbn->runtime();
|
||||
std::string T = "float";
|
||||
if (mtbn->useFp16InsteadFp32()) {
|
||||
T = "half";
|
||||
}
|
||||
std::vector<std::string> keys = {
|
||||
"matmul_qk_div_mask",
|
||||
T
|
||||
};
|
||||
auto pipeline = rt->findPipeline(keys);
|
||||
if (nil == pipeline) {
|
||||
// Rebuild Pipeline
|
||||
MTLCompileOptions *decodeOption = [[MTLCompileOptions alloc] init];
|
||||
decodeOption.preprocessorMacros = @{
|
||||
@"T" : @(T.c_str()),
|
||||
};
|
||||
|
||||
MTLCompileOptions* encodeOption = [[MTLCompileOptions alloc] init];
|
||||
encodeOption.preprocessorMacros = @{
|
||||
@"T" : @(T.c_str()),
|
||||
@"FOR_PREFILL": @"1"
|
||||
};
|
||||
mKernel_qk = mtbn->makeComputePipelineWithSourceOption(gMatMulDivMask, "main0", decodeOption);
|
||||
mKernelPrefill_qk = mtbn->makeComputePipelineWithSourceOption(gMatMulDivMask, "main0", encodeOption);
|
||||
mKernel_qkv = mtbn->makeComputePipelineWithSourceOption(gMatMulQKV, "main0", decodeOption);
|
||||
mKernelPrefill_qkv = mtbn->makeComputePipelineWithSourceOption(gMatMulQKV, "main0", encodeOption);
|
||||
|
||||
rt->insertPipeline({"matmul_qk_div_mask", T}, mKernel_qk);
|
||||
rt->insertPipeline({"matmul_qk_div_mask", T, "PREFILL"}, mKernelPrefill_qk);
|
||||
rt->insertPipeline({"matmul_qkv", T}, mKernel_qkv);
|
||||
rt->insertPipeline({"matmul_qkv", T, "PREFILL"}, mKernelPrefill_qkv);
|
||||
} else {
|
||||
mKernel_qk = rt->findPipeline({"matmul_qk_div_mask", T});
|
||||
mKernelPrefill_qk = rt->findPipeline({"matmul_qk_div_mask", T, "PREFILL"});
|
||||
mKernel_qkv = rt->findPipeline({"matmul_qkv", T});
|
||||
mKernelPrefill_qkv = rt->findPipeline({"matmul_qkv", T, "PREFILL"});
|
||||
}
|
||||
MNN_ASSERT(nil != mKernel_qk);
|
||||
MNN_ASSERT(nil != mKernel_qkv);
|
||||
MNN_ASSERT(nil != mKernelPrefill_qk);
|
||||
MNN_ASSERT(nil != mKernelPrefill_qkv);
|
||||
}
|
||||
|
||||
void AttentionBufExecution::reallocKVCache() {
|
||||
if (mCache->mPastLength < mCache->mMaxLength || nullptr == mTempQK || (!mIsDecode)) {
|
||||
if (mIsDecode) {
|
||||
mTempQK.reset(Tensor::createDevice<float>({mNumHead, mCache->mMaxLength}));
|
||||
mTempSoftMax.reset(Tensor::createDevice<float>({mNumHead, mCache->mMaxLength}));
|
||||
} else {
|
||||
mTempQK.reset(Tensor::createDevice<float>({mNumHead, mCache->mPastLength, mCache->mPastLength}));
|
||||
mTempSoftMax.reset(Tensor::createDevice<float>({mNumHead, mCache->mPastLength, mCache->mPastLength}));
|
||||
}
|
||||
backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC);
|
||||
backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC);
|
||||
}
|
||||
if (!mKVCache || mCache->mPastLength < mCache->mMaxLength) {
|
||||
return;
|
||||
}
|
||||
auto mtbn = static_cast<MetalBackend *>(backend());
|
||||
int byte = 4;
|
||||
if(mtbn->useFp16InsteadFp32()) {
|
||||
byte = 2;
|
||||
}
|
||||
bool needCopy = mCache->mMaxLength > 0;
|
||||
|
||||
size_t old_size = mNumHead * mCache->mMaxLength * mHeadDim * byte;
|
||||
mCache->mMaxLength = mCache->mPastLength + mExpandChunk;
|
||||
// past_key: [1, numhead, headdim, maxlen]
|
||||
auto new_key = Tensor::createDevice<float>({mCache->mMaxLength, mNumHead, mHeadDim});
|
||||
// past_value: [1, numhead, maxlen, headdim]
|
||||
auto new_value = Tensor::createDevice<float>({mCache->mMaxLength, mNumHead, mHeadDim});
|
||||
size_t size = mNumHead * mCache->mMaxLength * mHeadDim * byte;
|
||||
backend()->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
backend()->onAcquireBuffer(new_value, Backend::STATIC);
|
||||
if (needCopy) {
|
||||
auto newKeyBuf = MetalBackend::getBuffer(new_key);
|
||||
auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second;
|
||||
auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get());
|
||||
auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;
|
||||
::memcpy(new_key_ptr, key_ptr, old_size);
|
||||
|
||||
auto newValueBuf = MetalBackend::getBuffer(new_value);
|
||||
auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second;
|
||||
auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get());
|
||||
auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second;
|
||||
::memcpy(new_value_ptr, value_ptr, old_size);
|
||||
}
|
||||
mCache->mPastKey.reset(new_key);
|
||||
mCache->mPastValue.reset(new_value);
|
||||
}
void AttentionBufExecution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto mask = inputs[3];
    auto mtbn = static_cast<MetalBackend *>(backend());
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    auto shape = query->shape();

    int seq_len = shape[1];
    mNumHead = shape[2];
    mHeadDim = shape[3];
    mScale = 1.0 / sqrt(mHeadDim);
    mIsDecode = seq_len == 1;
    if (mCache->mPastLength == 0 || seq_len > 1) {
        mCache->mPastLength = seq_len;
    }
    mCache->mKv_seq_len = mCache->mPastLength;
    if(mIsDecode){
        mCache->mKv_seq_len = mCache->mPastLength + 1;
    }
    reallocKVCache();

    // Update Parameters
    {
        auto param = (Param*)mParamQKV.contents;
        param->scale = mScale;
        param->head_dim = mHeadDim;
        param->key_seq_len = mCache->mKv_seq_len;
        param->head_num = mNumHead;
        param->query_seq_len = seq_len;
    }
    // For softmax parameter
    int inside, outside;
    if (mIsDecode) {
        inside = mNumHead;
        outside = 1;
    } else {
        inside = 1;
        outside = mCache->mKv_seq_len * mNumHead;
    }
    int axis = mCache->mKv_seq_len;
    {
        auto softmax = (int*)mParamSoftmax.contents;
        // Inside, axis, outside, plane(invalid)
        softmax[0] = inside;
        softmax[1] = axis;
        softmax[2] = outside;
        softmax[3] = 0;
    }
    // Run QK Kernel
    {
        id<MTLComputePipelineState> pipeline;
        if (mIsDecode) {
            pipeline = mKernel_qk;
        } else {
            pipeline = mKernelPrefill_qk;
        }
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(query, encoder, 0);
        MetalBackend::setTensor(key, encoder, 1);
        MetalBackend::setTensor(mTempQK.get(), encoder, 2);
        MetalBackend::setTensor(mCache->mPastKey.get(), encoder, 3);
        MetalBackend::setTensor(mask, encoder, 4);
        [encoder setBuffer:mParamQKV offset:0 atIndex:5];
        auto gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mCache->mKv_seq_len)];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
    // Run Softmax Kernel
    {
        [encoder setComputePipelineState:mKernel_softmax];
        MetalBackend::setTensor(mTempQK.get(), encoder, 0);
        MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1);
        [encoder setBuffer:mParamSoftmax offset:0 atIndex:2];
        auto gl = [context computeBestGroupAndLocal: mKernel_softmax threads:MTLSizeMake(inside, outside, 1)];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
    // Run QKV Kernel
    {
        id<MTLComputePipelineState> pipeline;
        if (mIsDecode) {
            pipeline = mKernel_qkv;
        } else {
            pipeline = mKernelPrefill_qkv;
        }
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0);
        MetalBackend::setTensor(value, encoder, 1);
        MetalBackend::setTensor(outputs[0], encoder, 2);
        MetalBackend::setTensor(mCache->mPastValue.get(), encoder, 3);
        [encoder setBuffer:mParamQKV offset:0 atIndex:4];
        auto gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mHeadDim)];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
    // Update status
    if(mIsDecode){
        mCache->mPastLength += 1;
        mCache->mKv_seq_len = mCache->mPastLength + 1;
    }

    return;
}

class AttentionBufCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *> &outputs) const override {
        auto param = op->main_as_AttentionParam();
        return new AttentionBufExecution(backend, param->kv_cache());
    }
};
REGISTER_METAL_OP_TRANSFORMER_CREATOR(AttentionBufCreator, OpType_Attention);

} // namespace MNN
#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */
#endif

@ -137,7 +137,7 @@ public:
     * @param creator registering creator.
     */
    static void addCreator(OpType type, Creator *creator);
    static void setTensor(MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index);
    static void setTensor(const MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index);
    static std::pair<id<MTLBuffer>, int> getBuffer(MNN::Tensor* tensor);
    size_t getTensorSizeInBytes(const Tensor* tensor) const;
    virtual bool onSelectDynamicAllocator(int index, int maxIndex) override;

@ -264,5 +264,10 @@ public:
    MetalBackend::addCreator(opType, new name); \
}

#define REGISTER_METAL_OP_TRANSFORMER_CREATOR(name, opType) \
void ___##name##__##opType##__() { \
    MetalBackend::addCreator(opType, new name); \
}

#endif /* MNN_METAL_ENABLED */
#endif /* MetalBackend_hpp */

@ -5,12 +5,12 @@
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalBackend.hpp"
#define MNN_METAL
#import <MNN/MNNSharedContext.h>
#define METAL_CONST_BUFFER_LIMIT 128
#if MNN_METAL_ENABLED
#import <mutex>
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "core/TensorUtils.hpp"

@ -126,24 +126,30 @@ size_t MetalBackend::getTensorSizeInBytes(const Tensor* tensor) const {
    auto format = TensorUtils::getDescribe(tensor)->dimensionFormat;
    size_t size;
    if (MNN_DATA_FORMAT_NC4HW4 == format && tensor->dimensions() >= 2) {
        size_t width = 1;
        size_t height = 1;
        auto batch = tensor->length(0);
        auto channel = tensor->length(1);
        int width = 1;
        int height = 1;
        int batch = tensor->length(0);
        int channel = tensor->length(1);
        if (tensor->dimensions() >= 3) {
            height = tensor->length(2);
        }
        for (int i=3; i<tensor->dimensions(); ++i) {
            width *= tensor->length(i);
        }
        auto alignC = ROUND_UP(channel, 4);
        auto hR = ROUND_UP(height, 4) - height;
        int alignC = ROUND_UP(channel, 4);
        int hR = ROUND_UP(height, 4) - height;
        // width parallel 4, may exceed 3 elements
        auto wR = ROUND_UP(width + 3, 4) - width;
        int wR = ROUND_UP(width + 3, 4) - width;
        int bhw = batch * width * height;
        int bhwR = UP_DIV(bhw, 16) * 16 - bhw;
        int extraPadding = ALIMAX(bhwR, (hR * width + wR));
        size = batch * alignC * width * height;
        size = size + hR * width * 4 + wR * 4;
        size = size + extraPadding * 4;
    } else {
        size = tensor->elementSize();
        size = 1;
        for (int i=0; i<tensor->dimensions(); ++i) {
            size *= tensor->length(i);
        }
        size = ROUND_UP(size, 4);
    }
    if (0 == size) {

|
|||
// shape
|
||||
((int *)shape.contents)[0] = s;
|
||||
((int *)shape.contents)[1] = c;
|
||||
((int *)shape.contents)[2] = z;
|
||||
((int *)shape.contents)[2] = b;
|
||||
((int *)shape.contents)[3] = b * z;
|
||||
|
||||
// threads
|
||||
|
|
@ -459,13 +465,6 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
    auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
    auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *) (dst->deviceId()))->getBuffer();
    auto floats = src->getType().code == halide_type_float;
    std::unique_ptr<Tensor> tempSrc;
    if (sfmt == dfmt && sfmt == MNN_DATA_FORMAT_NC4HW4) {
        tempSrc.reset(new Tensor(src, Tensor::CAFFE));
        MNNCPUCopyBuffer(src, tempSrc.get());
        src = tempSrc.get();
        sfmt = TensorUtils::getDescribe(src)->dimensionFormat;
    }
    // For command queue from user, need user to make sure last frame's gpu work is ready
    bool needWait = !mRuntime->userSync();
    // cast

@ -486,7 +485,8 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
    };
    ::memcpy(mShapeH2D.contents, limits, sizeof(limits));
    auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
    auto bandwidth = [ctx load: @"downcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
    auto pipeline = [ctx pipelineWithName:@"downcast_float4" fp16:mUseFloatAsFp16];
    [encoder setComputePipelineState:pipeline];

    [encoder setBuffer:host offset:0 atIndex:0];
    [encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];

@ -494,7 +494,7 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
    //[ctx dispatchEncoder:encoder threads:{sizeC4, 1, 1} bandwidth:bandwidth];
    std::pair<MTLSize, MTLSize> threads;
    threads.first = {sizeC4, 1, 1};
    threads.second = {bandwidth.maxThreadsPerThreadgroup, 1, 1};
    threads.second = {[pipeline maxTotalThreadsPerThreadgroup], 1, 1};
    threads.second.width = threads.second.width <= threads.first.width ? threads.second.width : threads.first.width;
    threads.first.width = UP_DIV(threads.first.width, threads.second.width);
    [encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];

@ -520,13 +520,14 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
    auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
    auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Down);
    MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt

    auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
    auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
    [encoder setComputePipelineState:pipeline];

    [encoder setBuffer:buffer offset:0 atIndex:0];
    [encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
    [encoder setBuffer:mShapeH2D offset:0 atIndex:2];
    [ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
    auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
    [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    [encoder endEncoding];
    commit();
    //[ctx wait];

@ -539,16 +540,6 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
    auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
    auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer();
    auto floats = src->getType().code == halide_type_float;
    std::shared_ptr<Tensor> tempDst;
    if (sfmt == dfmt && sfmt == MNN_DATA_FORMAT_NC4HW4) {
        tempDst.reset(new Tensor(dst, Tensor::CAFFE), [dst](void* t) {
            auto tensor = (Tensor*)t;
            MNNCPUCopyBuffer(tensor, dst);
            delete tensor;
        });
        dst = tempDst.get();
        dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
    }
    // cast
    if (sfmt == dfmt || src->dimensions() <= 1) {
        if (floats && mUseFloatAsFp16) {

@ -558,7 +549,8 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons

    NSUInteger size = src->elementSize();
    auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
    auto bandwidth = [ctx load: @"upcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
    auto pipeline = [ctx pipelineWithName:@"upcast_float4" fp16:mUseFloatAsFp16];
    [encoder setComputePipelineState:pipeline];
    [encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
    [encoder setBuffer:buffer offset:0 atIndex:1];
    auto sizeC4 = UP_DIV(size, 4);

@ -573,7 +565,7 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
    //[ctx dispatchEncoder:encoder threads:{sizeC4, 1, 1} bandwidth:bandwidth];
    std::pair<MTLSize, MTLSize> threads;
    threads.first = {sizeC4, 1, 1};
    threads.second = {bandwidth.maxThreadsPerThreadgroup, 1, 1};
    threads.second = {[pipeline maxTotalThreadsPerThreadgroup], 1, 1};
    threads.second.width = threads.second.width <= threads.first.width ? threads.second.width : threads.first.width;
    threads.first.width = UP_DIV(threads.first.width, threads.second.width);
    [encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];

@ -597,18 +589,28 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
    auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Up);
    MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt

    auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
    auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
    [encoder setComputePipelineState:pipeline];
    [encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
    [encoder setBuffer:buffer offset:0 atIndex:1];
    [encoder setBuffer:mShapeD2H offset:0 atIndex:2];
    [ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
    auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
    [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    [encoder endEncoding];
    commit();
    wait();
    memcpy(dst->host<float>(), buffer.contents, dst->size());
    }
}

static const char* gCopy = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
kernel void main0(const device int4 *in [[buffer(0)]], device int4 *out [[buffer(1)]], constant uint4& limit [[buffer(2)]], uint gid [[thread_position_in_grid]]) {
    if (gid < limit.x) {
        out[int(gid)] = in[int(gid)];
    }
})metal";
void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
                                        id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const {
    auto ctx = (__bridge MNNMetalContext *)context();

@ -619,12 +621,28 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,

    // copy
    if (sfmt == dfmt || src->dimensions() <= 1) {
        auto flt = dst->getType().code == halide_type_float;
        auto size = flt ? dst->elementSize() : dst->size();
        auto bandwidth = [ctx load:flt ? @"copy_float" : @"copy_byte" encoder:encoder fp16:mUseFloatAsFp16];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
        [ctx dispatchEncoder:encoder threads:{(NSUInteger)size, 1, 1} bandwidth:bandwidth];
        auto size = dst->usize();
        if (mUseFloatAsFp16 && dst->getType().code == halide_type_float) {
            size = size / 2;
        }
        size = UP_DIV(size, (4 * sizeof(float)));
        std::vector<std::string> keys = {
            "copyC4"
        };
        id<MTLComputePipelineState> pipeline = mRuntime->findPipeline(keys);
        if (nil == pipeline) {
            pipeline = makeComputePipelineWithSourceOption(gCopy, "main0", nil);
            mRuntime->insertPipeline(keys, pipeline);
        }
        [encoder setComputePipelineState:pipeline];
        if (shape == nil) {
            shape = getConstBuffer(4 * sizeof(int));
        }
        ((uint32_t*)[shape contents])[0] = size;
        setTensor(src, encoder, 0);
        setTensor(dst, encoder, 1);
        [encoder setBuffer:shape offset:0 atIndex:2];
        [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
    }
    // convert
    else {

@ -635,11 +653,13 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
        }

        auto size = getTensorShape(shape, src);
        auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
        auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(src->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
        [encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(dst->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
        [encoder setBuffer:shape offset:0 atIndex:2];
        [ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
        auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }

    if (standalone) {

@ -712,7 +732,7 @@ id<MTLCommandBuffer> MetalBackend::getCommandBufferForNet() const {
    return _commandBuffer_net;
}

void MetalBackend::setTensor(MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index) {
void MetalBackend::setTensor(const MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index) {
    [encoder setBuffer:((MetalRuntimeAllocator::MetalBufferAlloc *)tensor->deviceId())->getBuffer() offset:TensorUtils::getDescribe(tensor)->extra.offset atIndex:index];
}
std::pair<id<MTLBuffer>, int> MetalBackend::getBuffer(MNN::Tensor* tensor) {

@ -18,6 +18,7 @@ def genRegister():
f.write(" namespace MNN {\n")
f.write("#if MNN_METAL_ENABLED\n")
funcs=[]
transformerFuncs = []
for shapath in shaders:
with open(shapath,"r") as sha:
lines=sha.readlines()

@ -27,10 +28,22 @@ def genRegister():
funcname="___"+x[0]+"__"+x[1]+"__();"
funcs.append(funcname)
f.write(" extern void "+funcname+"\n")
elif l.startswith('REGISTER_METAL_OP_TRANSFORMER_CREATOR('):
x=l.replace("REGISTER_METAL_OP_TRANSFORMER_CREATOR(","").replace(")","").replace(" ","").replace(";","").replace("\n","").split(",")
funcname="___"+x[0]+"__"+x[1]+"__();"
transformerFuncs.append(funcname)
f.write("#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n")
f.write(" extern void "+funcname+"\n")
f.write('#endif\n')

pass
f.write("void registerMetalOps() {\n")
for func in funcs:
f.write(" "+func+"\n")
f.write('#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n')
for func in transformerFuncs:
f.write(" "+func+"\n")
f.write('#endif\n')
f.write("}\n#endif\n}")
if os.path.isdir(renderPath):
shaders=[]

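// For reference, a minimal sketch of what the generator above emits into the register file
// (names taken from the register hunk further below; transformer ops are fenced by
// MNN_SUPPORT_TRANSFORMER_FUSE):
//   extern void ___MetalConvolutionCreator__OpType_Convolution__();
//   #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
//   extern void ___AttentionBufCreator__OpType_Attention__();
//   #endif
//   void registerMetalOps() {
//       ___MetalConvolutionCreator__OpType_Convolution__();
//   #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
//       ___AttentionBufCreator__OpType_Attention__();
//   #endif
//   }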
@ -22,7 +22,7 @@ public:
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;

protected:
    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;

private:
    MetalConvolution(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias);

@ -175,14 +175,13 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
    return NO_ERROR;
}

void MetalConvolution::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
    auto oc_4 = UP_DIV(output->channel(), 4);
void MetalConvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input = inputs[0];
    auto output = outputs[0];

    auto bandwidth = (MetalBandwidth){mPipeline.threadExecutionWidth, mPipeline.maxTotalThreadsPerThreadgroup, NO};

    [encoder setComputePipelineState:mPipeline];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
    MetalBackend::setTensor(input, encoder, 0);
    MetalBackend::setTensor(output, encoder, 1);
    [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
    MetalBackend::setTensor(mWeight.get(), encoder, 3);
    MetalBackend::setTensor(mBias.get(), encoder, 4);

@ -21,10 +21,9 @@ public:
    virtual ~MetalConvolution1x1() = default;
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
protected:
    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
    MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, std::shared_ptr<MNN::Tensor> dequantBias, int dequantBits);
    MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits);
    id<MTLComputePipelineState> mPipeline;
    std::pair<MTLSize, MTLSize> mThreads;
};

@ -31,11 +31,10 @@ MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op) :
    loadWeight(op->main_as_Convolution2D(), ldInt8Weight);
}

MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, std::shared_ptr<MNN::Tensor> dequantBias, int dequantBits) : MetalConvolutionCommon(backend, op, bias) {
MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits) : MetalConvolutionCommon(backend, op, bias) {
    mWeight = weight;
    mBias = bias;
    mDequantScale = dequantScale;
    mDequantZero = dequantBias;
    mDequantScaleBias = dequantScale;
    mDequantBits = dequantBits;
}

@ -47,7 +46,7 @@ bool MetalConvolution1x1::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (nullptr == dst) {
        return true;
    }
    *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScale, mDequantZero, mDequantBits);
    *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits);
    return true;
}

@ -55,15 +54,18 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
    MetalConvolutionCommon::onResize(inputs, outputs);

    // prepare
    // For C4NHW4 format, NHW can be fuse to W
    auto input = inputs[0];
    auto output = outputs[0];

    auto is = input->width() * input->height();
    auto ic_4 = UP_DIV(input->channel(), 4);
    auto ow = output->width();
    auto oh = output->height();
    auto os = ow * oh;
    auto ob = output->batch();
    int is = input->batch();
    for (int i=2; i<input->dimensions(); ++i) {
        is *= input->length(i);
    }
    int ic_4 = UP_DIV(input->channel(), 4);
    int ow = is;
    int oh = 1;
    int os = ow;
    int ob = 1;
    auto oc = output->channel();
    auto oc_4 = UP_DIV(output->channel(), 4);
    auto backend = static_cast<MetalBackend *>(this->backend());

@ -75,7 +77,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
    ::memcpy(mConstBuffer.contents, constants, sizeof(constants));

    MetalRuntime* rt = (MetalRuntime *)backend->runtime();
    if (ow == input->width() && oh == input->height() && mDequantScale.get()) {
    if (mDequantScaleBias.get()) {
        NSUInteger gid_x = UP_DIV(ow * oh, 4);
        NSUInteger gid_y = oc_4;
        NSUInteger gid_z = ob;

@ -89,8 +91,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
            (id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(),
            mConstBuffer, (((MetalRuntimeAllocator::MetalBufferAlloc *)mWeight->deviceId()))->getBuffer(),
            ((MetalRuntimeAllocator::MetalBufferAlloc *)mBias->deviceId())->getBuffer(),
            (((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantScale->deviceId()))->getBuffer(),
            (((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantZero->deviceId()))->getBuffer(), nil];
            (((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantScaleBias->deviceId()))->getBuffer(),
            nil];
        const Tensor* weight = mWeight.get();
        const Tensor* bias = mBias.get();
        int buffer_offset[] = {

@ -99,8 +101,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
            0,
            TensorUtils::getDescribe(weight)->extra.offset,
            TensorUtils::getDescribe(bias)->extra.offset,
            TensorUtils::getDescribe(mDequantScale.get())->extra.offset,
            TensorUtils::getDescribe(mDequantZero.get())->extra.offset,
            TensorUtils::getDescribe(mDequantScaleBias.get())->extra.offset,
            0};

        MetalRuntime *rt = (MetalRuntime *)backend->runtime();

@ -147,9 +148,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
            //printf("conv1x1_z4, %d %d %d %d\n", ow, oh, oc_4, ic_4);
        }
    } else {
        NSString* shaderName[] = {@"conv1x1_w4h2", @"conv1x1_w2h2", @"conv1x1_w4h4", @"conv1x1_w2c2", @"conv1x1_w2h2c2"};
        int itemW[] = {4, 2, 4, 2, 2};
        int itemH[] = {2, 2, 4, 1, 2};
        NSString* shaderName[] = {@"conv1x1_g1z8", @"conv1x1_g1z4", @"conv1x1_w4h4", @"conv1x1_w2c2", @"conv1x1_w4c2"};
        int itemW[] = {8, 4, 16, 2, 4};
        int itemC[] = {4, 4, 4, 8, 8};
        int actual_kernel = 5;
        if (oc_4 % 2 != 0) {

@ -168,8 +168,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
        for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
            id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx] fp16:backend->useFp16InsteadFp32()];
            NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]);
            NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]);
            NSUInteger gid_z = ob * UP_DIV(oc, itemC[knl_idx]);
            NSUInteger gid_y = UP_DIV(oc, itemC[knl_idx]);
            NSUInteger gid_z = 1;

            std::string name = [shaderName[knl_idx] UTF8String];
            auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name offsets:buffer_offset queue:backend->queue()];

@ -188,16 +188,17 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
    return NO_ERROR;
}

void MetalConvolution1x1::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolution1x1::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input = inputs[0];
    auto output = outputs[0];
    [encoder setComputePipelineState:mPipeline];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
    [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
    MetalBackend::setTensor(mWeight.get(), encoder, 3);
    MetalBackend::setTensor(mBias.get(), encoder, 4);
    if (mDequantScale && mDequantZero) {
        MetalBackend::setTensor(mDequantScale.get(), encoder, 5);
        MetalBackend::setTensor(mDequantZero.get(), encoder, 6);
    if (mDequantScaleBias) {
        MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 5);
    }
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

@ -20,12 +20,10 @@ class MetalConvolutionCommon : public MetalExecution {
public:
    MetalConvolutionCommon(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> bias);
    virtual ~MetalConvolutionCommon() = default;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;

protected:
    void loadWeight(const MNN::Convolution2D *conv, bool loadWeightInt8 = false);

    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) = 0;
    virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight = false, bool int4Weight = false);

private:

@ -42,8 +40,7 @@ protected:

    std::shared_ptr<MNN::Tensor> mWeight;
    std::shared_ptr<MNN::Tensor> mBias;
    std::shared_ptr<MNN::Tensor> mDequantScale;
    std::shared_ptr<MNN::Tensor> mDequantZero;
    std::shared_ptr<MNN::Tensor> mDequantScaleBias;
    int mDequantBits;
    id<MTLBuffer> mConstBuffer = nil;
};

@ -68,10 +68,6 @@ MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *
    }
}

void MetalConvolutionCommon::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    return onFloat(inputs[0], outputs[0], encoder);
}

template <typename FType, typename TType>
void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src, uint8_t* dstOrigion) {
    auto goc = oc / group;

@ -101,25 +97,24 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src,
    }
}

static std::vector<std::shared_ptr<MNN::Tensor>> getDequantScale(float* scale, int size, MetalBackend *backend, bool asymmetric) {
static std::shared_ptr<MNN::Tensor> getDequantScale(float* scale, int size, MetalBackend *backend, bool asymmetric) {
    int outputCount = 0;
    if (asymmetric) {
        outputCount = size / 2;
    } else {
        outputCount = size;
    }
    std::vector<std::shared_ptr<MNN::Tensor>> scaleBias(2);
    std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({outputCount * 4}));
    std::shared_ptr<MNN::Tensor> dequantBias(MNN::Tensor::createDevice<uint8_t>({outputCount * 4}));
    bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC) && backend->onAcquireBuffer(dequantBias.get(), Backend::STATIC);
    int alignOutputCount = ALIGN_UP4(outputCount);
    std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({(int)(alignOutputCount * sizeof(float) * 2)}));
    bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC);
    if (!res) {
        MNN_ERROR("Buffer allocated error!\n");
        return scaleBias;
        return nullptr;
    }
    auto buffer0 = MetalBackend::getBuffer(dequantScale.get());
    auto dst_scale = (uint8_t*)[buffer0.first contents] + buffer0.second;
    auto buffer1 = MetalBackend::getBuffer(dequantBias.get());
    auto dst_bias = (uint8_t*)[buffer1.first contents] + buffer1.second;
    ::memset(dst_scale, 0, alignOutputCount * 2 * sizeof(float));
    auto dst_bias = dst_scale + alignOutputCount * sizeof(float);
    for (int o = 0; o < outputCount; ++o) {
        float min = 0.0f;
        float alpha = 0.0f;

@ -131,10 +126,8 @@ static std::vector<std::shared_ptr<MNN::Tensor>> getDequantScale(float* scale, i
        }
        ((float*)dst_scale)[o] = alpha;
        ((float*)dst_bias)[o] = min;
    }
    scaleBias[0] = dequantScale;
    scaleBias[1] = dequantBias;
    return scaleBias;
}
    return dequantScale;
}
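// Note: after this change the quantization parameters live in a single packed buffer.
// The first ALIGN_UP4(outputCount) floats hold the per-channel scale (alpha) and the next
// ALIGN_UP4(outputCount) floats hold the per-channel bias (min), so kernels read both values
// from one tensor (mDequantScaleBias) instead of separate scale/zero tensors.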
void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv, bool loadWeightInt8) {
    std::shared_ptr<ConvolutionCommon::Int8Common> qnt = NULL;

@ -157,8 +150,7 @@ void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv, bool loa
        auto backend = static_cast<MetalBackend *>(this->backend());
        mWeight = weightTransform(group, oc, ic, kh, kw, (float*)qnt->weight.get(), !qnt->canUseInt4, qnt->canUseInt4);
        auto dequantParams = getDequantScale(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric);
        mDequantScale = dequantParams[0];
        mDequantZero = dequantParams[1];
        mDequantScaleBias = dequantParams;
        mDequantBits = qnt->canUseInt4 ? 4:8;
    } else if (qnt && qnt->weightFloat.size() > 0) {
        mWeight = weightTransform(group, oc, ic, kh, kw, qnt->weightFloat.get(), false, false);

@ -21,7 +21,7 @@ public:
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

protected:
    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
    virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight=false, bool int4Weight=false) override;
private:
    id<MTLComputePipelineState> mPipeline;

@ -82,12 +82,10 @@ ErrorCode MetalConvolutionDepthwise::onResize(const std::vector<Tensor *> &input
    return NO_ERROR;
}

void MetalConvolutionDepthwise::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolutionDepthwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    [encoder setComputePipelineState:mPipeline];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset: TensorUtils::getDescribe(input)->extra.offset
            atIndex:0];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset: TensorUtils::getDescribe(output)->extra.offset
            atIndex:1];
    MetalBackend::setTensor(inputs[0], encoder, 0);
    MetalBackend::setTensor(outputs[0], encoder, 1);
    [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
    MetalBackend::setTensor(mWeight.get(), encoder, 3);
    MetalBackend::setTensor(mBias.get(), encoder, 4);

@ -23,7 +23,7 @@ public:
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;

protected:
    virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
    virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight=false, bool int4Weight=false) override;

private:

@ -141,32 +141,40 @@ ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs
    return NO_ERROR;
}
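// Note: the Winograd onEncode below keeps the original three-stage pipeline
// (source transform -> matmul4x4 -> dest transform); the change swaps the
// [context load:...]/dispatchEncoder:bandwidth: path for explicit pipeline objects
// dispatched via computeBestGroupAndLocal.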
void MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolutionWinograd::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input = inputs[0];
    auto output = outputs[0];
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();

    { // transform
        auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:1];
        auto pipeline = [context pipelineWithName:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" fp16:backend->useFp16InsteadFp32()];
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(input, encoder, 0);
        MetalBackend::setTensor(mTempSrc.get(), encoder, 1);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
        [context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
        auto gl = [context computeBestGroupAndLocal:pipeline threads:mInputTransformThreads];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
    { // gemm
        auto bandwidth = [context load:@"matmul4x4" encoder:encoder fp16:backend->useFp16InsteadFp32()];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:1];
        auto pipeline = [context pipelineWithName:@"matmul4x4" fp16:backend->useFp16InsteadFp32()];
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(mTempSrc.get(), encoder, 0);
        MetalBackend::setTensor(mTempDst.get(), encoder, 1);
        MetalBackend::setTensor(mWeight.get(), encoder, 2);
        [encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
        [context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
        auto gl = [context computeBestGroupAndLocal:pipeline threads:mMatMulThreads];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
    { // transform
        auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:0];
        auto pipeline = [context pipelineWithName:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" fp16:backend->useFp16InsteadFp32()];
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(mTempDst.get(), encoder, 0);
        MetalBackend::setTensor(mBias.get(), encoder, 1);
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
        MetalBackend::setTensor(output, encoder, 2);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
        [context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
        auto gl = [context computeBestGroupAndLocal:pipeline threads:mOutputTransformThreads];
        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
    }
}
std::shared_ptr<MNN::Tensor> MetalConvolutionWinograd::weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight, bool int4Weight) {

@ -1,3 +1,11 @@
//
// MetalExecution.hpp
// MNN
//
// Created by MNN on 2023/11/09.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MetalExecution_hpp
#define MetalExecution_hpp

@ -1,3 +1,11 @@
//
// MetalExecution.mm
// MNN
//
// Created by MNN on 2023/11/09.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "MetalExecution.hpp"
#import "backend/metal/MetalBackend.hpp"

@ -1,3 +1,11 @@
//
// MetalLoop.mm
// MNN
//
// Created by MNN on 2023/12/28.
// Copyright © 2018, Alibaba Group Holding Limited
//

#import "core/Macro.h"
#import "MetalCast.hpp"
#import "MetalBinary.hpp"

@ -18,7 +18,7 @@ namespace MNN {

class MetalMatMul : public MetalExecution {
public:
    MetalMatMul(Backend *backend, const MatMul *matmul);
    MetalMatMul(Backend *backend, const MatMul *matmul, bool withBias);
    virtual ~MetalMatMul();
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

@ -27,6 +27,8 @@ private:
    id<MTLBuffer> mConstBuffer = nil;
    bool mTransposeA = false;
    bool mTransposeB = false;
    id<MTLComputePipelineState> mPipeline;
    std::pair<MTLSize, MTLSize> mThreads;
};

} // namespace MNN

@ -18,11 +18,17 @@ struct matP {
    int size[4];
    int stride[4];
};
MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul) : MetalExecution(backend) {
MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul, bool withBias) : MetalExecution(backend) {
    mTransposeA = matmul->transposeA();
    mTransposeB = matmul->transposeB();
    auto mkbn = static_cast<MetalBackend *>(backend);
    mConstBuffer = mkbn->getConstBuffer(sizeof(matP));
    auto context = (__bridge MNNMetalContext *)mkbn->context();
    if (withBias) {
        mPipeline = [context pipelineWithName:@"matmul_bias" fp16:mkbn->useFp16InsteadFp32()];
    } else {
        mPipeline = [context pipelineWithName:@"matmul" fp16:mkbn->useFp16InsteadFp32()];
    }
}
MetalMatMul::~MetalMatMul() {
    auto mkbn = static_cast<MetalBackend *>(backend());

@ -59,6 +65,9 @@ ErrorCode MetalMatMul::onResize(const std::vector<Tensor *> &inputs, const std::
    }

    ::memcpy(mConstBuffer.contents, &buffer, sizeof(matP));
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
    mThreads = [context computeBestGroupAndLocal:mPipeline threads: MTLSizeMake(h, e, 1)];
    return NO_ERROR;
}

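// Note: with this change MetalMatMul picks its pipeline ("matmul" vs "matmul_bias") once in the
// constructor based on whether a bias input exists, and onResize precomputes mThreads with
// computeBestGroupAndLocal, so onEncode below only binds tensors and dispatches.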
@ -71,25 +80,20 @@ void MetalMatMul::onEncode(const std::vector<Tensor *> &inputs, const std::vecto
    auto h = C->length(1);

    if (inputs.size() > 2) {
        auto bandwidth = [context load:@"matmul_bias" encoder:encoder fp16:backend->useFp16InsteadFp32()];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[2]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[2])->extra.offset atIndex:2];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:3];
        [encoder setComputePipelineState:mPipeline];
        MetalBackend::setTensor(input0, encoder, 0);
        MetalBackend::setTensor(input1, encoder, 1);
        MetalBackend::setTensor(inputs[2], encoder, 2);
        MetalBackend::setTensor(output, encoder, 3);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:4];
        [context dispatchEncoder:encoder
                         threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
                       bandwidth:bandwidth];
    } else {
        auto bandwidth = [context load:@"matmul" encoder:encoder fp16:backend->useFp16InsteadFp32()];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
        [encoder setComputePipelineState:mPipeline];
        MetalBackend::setTensor(input0, encoder, 0);
        MetalBackend::setTensor(input1, encoder, 1);
        MetalBackend::setTensor(output, encoder, 2);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
        [context dispatchEncoder:encoder
                         threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
                       bandwidth:bandwidth];
    }
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

class MetalMatMulCreator : public MetalBackend::Creator {

@ -100,7 +104,7 @@ public:
            MNN_PRINT("metal not support matmul inpt size less than 2\n");
            return nullptr;
        }
        return new MetalMatMul(backend, op->main_as_MatMul());
        return new MetalMatMul(backend, op->main_as_MatMul(), inputs.size() > 2);
    }
};
REGISTER_METAL_OP_CREATOR(MetalMatMulCreator, OpType_MatMul);

@ -12,13 +12,15 @@
extern void ___MetalEltwiseCreator__OpType_Eltwise__();
extern void ___MetalConvolutionCreator__OpType_Convolution__();
extern void ___MetalLayerNormCreator__OpType_LayerNorm__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___AttentionBufCreator__OpType_Attention__();
#endif
extern void ___MetalMatMulCreator__OpType_MatMul__();
extern void ___MetalBinaryCreator__OpType_BinaryOp__();
extern void ___MetalConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__();
extern void ___MetalDeconvolutionCreator__OpType_Deconvolution__();
extern void ___MetalDeconvolutionCreator__OpType_DeconvolutionDepthwise__();
extern void ___MetalLoopCreator__OpType_While__();
extern void ___MetalReLUCreator__OpType_ReLU__();
extern void ___MetalPoolingCreator__OpType_Pooling__();
extern void ___MetalScaleCreator__OpType_Scale__();
extern void ___MetalInterpCreator__OpType_Interp__();

@ -31,6 +33,7 @@
extern void ___MetalFuseCreator__OpType_Extra__();
extern void ___MetalPReLUCreator__OpType_PReLU__();
extern void ___MetalReLU6Creator__OpType_ReLU6__();
extern void ___MetalReLU6Creator__OpType_ReLU__();
void registerMetalOps() {
    ___MetalArgMaxCreator__OpType_ArgMax__();
    ___MetalArgMaxCreator__OpType_ArgMin__();

@ -48,7 +51,6 @@ void registerMetalOps() {
    ___MetalDeconvolutionCreator__OpType_Deconvolution__();
    ___MetalDeconvolutionCreator__OpType_DeconvolutionDepthwise__();
    ___MetalLoopCreator__OpType_While__();
    ___MetalReLUCreator__OpType_ReLU__();
    ___MetalPoolingCreator__OpType_Pooling__();
    ___MetalScaleCreator__OpType_Scale__();
    ___MetalInterpCreator__OpType_Interp__();

@ -61,6 +63,10 @@ void registerMetalOps() {
    ___MetalFuseCreator__OpType_Extra__();
    ___MetalPReLUCreator__OpType_PReLU__();
    ___MetalReLU6Creator__OpType_ReLU6__();
    ___MetalReLU6Creator__OpType_ReLU__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
    ___AttentionBufCreator__OpType_Attention__();
#endif
}
#endif
}

@ -24,6 +24,7 @@ public:
private:
    float mSpatialScale;
    id<MTLBuffer> mShape;
    id<MTLComputePipelineState> mPipeline;
};

} // namespace MNN

@ -16,7 +16,10 @@ namespace MNN {

MetalROIPooling::MetalROIPooling(Backend *backend, float spatialScale)
    : MetalExecution(backend), mSpatialScale(spatialScale) {
    // nothing to do
    auto mtbn = static_cast<MetalBackend *>(backend);
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    mShape = [context newDeviceBuffer:8 * sizeof(int) + sizeof(float) access:CPUWriteOnly];
    mPipeline = [context pipelineWithName:@"ROI_pooling" fp16:mtbn->useFp16InsteadFp32()];
}
ErrorCode MetalROIPooling::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());

@ -25,16 +28,15 @@ ErrorCode MetalROIPooling::onResize(const std::vector<Tensor *> &inputs, const s
    int iw = input->width(), ih = input->height();
    int ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4), ob = output->batch();

    auto shape = [context newDeviceBuffer:7 * sizeof(int) + sizeof(float) access:CPUWriteOnly];
    ((int *)shape.contents)[0] = iw;
    ((int *)shape.contents)[1] = ih;
    ((int *)shape.contents)[2] = iw * ih;
    ((int *)shape.contents)[3] = ow;
    ((int *)shape.contents)[4] = oh;
    ((int *)shape.contents)[5] = ow * oh;
    ((int *)shape.contents)[6] = oz;
    ((float *)shape.contents)[7] = mSpatialScale;
    mShape = shape;
    ((int *)mShape.contents)[0] = iw;
    ((int *)mShape.contents)[1] = ih;
    ((int *)mShape.contents)[2] = iw * ih;
    ((int *)mShape.contents)[3] = input->batch();
    ((int *)mShape.contents)[4] = ow;
    ((int *)mShape.contents)[5] = oh;
    ((int *)mShape.contents)[6] = ow * oh;
    ((int *)mShape.contents)[7] = ob;
    ((float *)mShape.contents)[8] = mSpatialScale;
    return NO_ERROR;
}

@ -44,18 +46,22 @@ void MetalROIPooling::onEncode(const std::vector<Tensor *> &inputs, const std::v
    auto input = inputs[0], roi = inputs[1], output = outputs[0];
    int iw = input->width(), ih = input->height();
    int ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4), ob = output->batch();

    auto bandwidth = [context load:@"ROI_pooling" encoder:encoder fp16:backend->useFp16InsteadFp32()];

    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)roi->deviceId())->getBuffer() offset:TensorUtils::getDescribe(roi)->extra.offset atIndex:1];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
    [encoder setComputePipelineState:mPipeline];
    MetalBackend::setTensor(input, encoder, 0);
    MetalBackend::setTensor(roi, encoder, 1);
    MetalBackend::setTensor(output, encoder, 2);
    [encoder setBuffer:mShape offset:0 atIndex:3];
    [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(ow, 16), UP_DIV(oh, 16), ob * oz) threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];
}

class MetalROIPoolingCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
        auto roi = inputs[1];
        if (TensorUtils::getDescribe(roi)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
            // Don't support old roipooling
            return nullptr;
        }
        return new MetalROIPooling(backend, op->main_as_RoiParameters()->spatialScale());
    }
};

@ -210,7 +210,7 @@ ErrorCode MetalRaster::onResize(const std::vector<Tensor *> &____inputs, const s
            fast = false;
            break;
        }
        if (!OpCommonUtils::canBlitFast(slice, output)) {
        if (!OpCommonUtils::canBlitFast(slice, output, 4, true)) {
            fast = false;
            break;
        }

@ -239,7 +239,7 @@ ErrorCode MetalRaster::onResize(const std::vector<Tensor *> &____inputs, const s
        for (int v=0; v<iter.second.size(); ++v) {
            auto& oldr = des->regions[iter.second[v]];
            Tensor::InsideDescribe::Region slice;
            OpCommonUtils::turnToPackRegion(oldr, slice, output, 4);
            OpCommonUtils::turnToPackRegion(oldr, slice, output, 4, true);
            slice.dst.offset /= 4;
            slice.src.offset /= 4;
            writeSamplerInfo(infoP[v], slice);

@ -1,28 +0,0 @@
//
// MetalReLU.hpp
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MetalReLU_hpp
#define MetalReLU_hpp

#import "MetalExecution.hpp"
#if MNN_METAL_ENABLED
namespace MNN {

class MetalReLU : public MetalExecution {
public:
    MetalReLU(Backend *backend, float slope);
    virtual ~MetalReLU() = default;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;

private:
    id<MTLBuffer> mSlope;
};

} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalReLU_hpp */
@ -1,51 +0,0 @@
//
// MetalReLU.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalReLU.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

MetalReLU::MetalReLU(Backend *backend, float slope) : MetalExecution(backend) {
    auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
    mSlope = [context newDeviceBuffer:sizeof(float) bytes:&slope access:CPUWriteOnly];
}

void MetalReLU::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();

    auto input = inputs[0], output = outputs[0];
    NSUInteger size = output->elementSize();
    auto simd = size % 4 == 0;
    if (simd) {
        size /= 4;
    }

    MNN_ASSERT(mSlope.length == sizeof(float));
    auto bandwidth = [context load:simd ? @"relu_x4" : @"relu_x1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
    [encoder setBuffer:mSlope offset:0 atIndex:2];
    [context dispatchEncoder:encoder threads:{ size, 1, 1 } bandwidth:bandwidth];
    MNN_PRINT_ENCODER(context, encoder);
}

class MetalReLUCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
        return new MetalReLU(backend, op->main_as_Relu()->slope());
    }
};
REGISTER_METAL_OP_CREATOR(MetalReLUCreator, OpType_ReLU);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */
@ -16,11 +16,12 @@ namespace MNN {

class MetalReLU6 : public MetalExecution {
public:
    MetalReLU6(Backend *backend, float minValue, float maxValue);
    MetalReLU6(Backend *backend, float minValue, float maxValue, bool isRelu);
    virtual ~MetalReLU6() = default;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
    id<MTLBuffer> mConst;
    id<MTLComputePipelineState> mPipeline;
};

} // namespace MNN

@ -15,33 +15,39 @@
|
|||
#if MNN_METAL_ENABLED
|
||||
namespace MNN {
|
||||
|
||||
MetalReLU6::MetalReLU6(Backend *backend, float minV, float maxV) : MetalExecution(backend) {
|
||||
MetalReLU6::MetalReLU6(Backend *backend, float minV, float maxV, bool isRelu) : MetalExecution(backend) {
|
||||
// For ReLU, minV carries the slope value
|
||||
auto metal = static_cast<MetalBackend *>(backend);
|
||||
auto context = (__bridge MNNMetalContext *)metal->context();
|
||||
mConst = [context newDeviceBuffer:4 * sizeof(float) access:CPUWriteOnly];
|
||||
mConst = [context newDeviceBuffer:4 * sizeof(float) access:CPUWriteOnly];
|
||||
((float*)mConst.contents)[0] = minV;
|
||||
((float*)mConst.contents)[1] = maxV;
|
||||
if (isRelu) {
|
||||
mPipeline = [context pipelineWithName:@"relu" fp16:metal->useFp16InsteadFp32()];
|
||||
} else {
|
||||
mPipeline = [context pipelineWithName:@"relu6" fp16:metal->useFp16InsteadFp32()];
|
||||
}
|
||||
}
|
||||
void MetalReLU6::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
|
||||
auto backend = static_cast<MetalBackend *>(this->backend());
|
||||
auto context = (__bridge MNNMetalContext *)backend->context();
|
||||
auto input = inputs[0], output = outputs[0];
|
||||
NSUInteger size = output->elementSize();
|
||||
auto simd = size % 4 == 0;
|
||||
if (simd) {
|
||||
size /= 4;
|
||||
}
|
||||
int size = output->elementSize();
|
||||
size = UP_DIV(size, 4);
|
||||
((int*)mConst.contents)[2] = size;
|
||||
((int*)mConst.contents)[3] = size;
|
||||
|
||||
auto bandwidth = [context load:simd ? @"relu6_x4" : @"relu6_x1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
|
||||
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
|
||||
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
|
||||
[encoder setComputePipelineState:mPipeline];
|
||||
MetalBackend::setTensor(input, encoder, 0);
|
||||
MetalBackend::setTensor(output, encoder, 1);
|
||||
[encoder setBuffer:mConst offset:0 atIndex:2];
|
||||
[context dispatchEncoder:encoder threads:{ size, 1, 1 } bandwidth:bandwidth];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
|
||||
}
|
||||
|
||||
class MetalReLU6Creator : public MetalBackend::Creator {
|
||||
public:
|
||||
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
|
||||
if (op->type() == OpType_ReLU) {
|
||||
return new MetalReLU6(backend, op->main_as_Relu()->slope(), 0.0f, true);
|
||||
}
|
||||
float minV = 0.0f;
|
||||
float maxV = 6.0f;
|
||||
if (nullptr != op->main()) {
|
||||
|
|
@ -49,9 +55,10 @@ public:
|
|||
minV = p->minValue();
|
||||
maxV = p->maxValue();
|
||||
}
|
||||
return new MetalReLU6(backend, minV, maxV);
|
||||
return new MetalReLU6(backend, minV, maxV, false);
|
||||
}
|
||||
};
|
||||
REGISTER_METAL_OP_CREATOR(MetalReLU6Creator, OpType_ReLU6);
|
||||
REGISTER_METAL_OP_CREATOR(MetalReLU6Creator, OpType_ReLU);
|
||||
} // namespace MNN
|
||||
#endif /* MNN_METAL_ENABLED */
|
||||
|
|
|
|||
|
|
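The rewritten MetalReLU6 packs {minV, maxV, size} into mConst, rounds the element count to ftype4 groups with UP_DIV, and always launches 256-wide threadgroups; out-of-range threads exit via the gid.x < p.size guard inside the kernel. A small C++ sketch of that launch arithmetic (illustrative only):

// UP_DIV(a, b) == (a + b - 1) / b, as used throughout MNN.
inline int UP_DIV(int a, int b) { return (a + b - 1) / b; }

// Number of 256-thread groups needed to cover `elementCount` scalars processed
// four at a time; the last group may run threads that immediately return.
inline int reluThreadgroups(int elementCount) {
    int vec4Count = UP_DIV(elementCount, 4);
    return UP_DIV(vec4Count, 256);
}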
@ -3,7 +3,6 @@
|
|||
namespace MNN {
|
||||
void ShaderMap::init() {
|
||||
mMaps.insert(std::make_pair("shader_MetalReLU6_metal", shader_MetalReLU6_metal));
|
||||
mMaps.insert(std::make_pair("shader_MetalReLU_metal", shader_MetalReLU_metal));
|
||||
mMaps.insert(std::make_pair("shader_MetalConvolutionDepthwise_metal", shader_MetalConvolutionDepthwise_metal));
|
||||
mMaps.insert(std::make_pair("shader_MetalConvolutionActivation_metal", shader_MetalConvolutionActivation_metal));
|
||||
mMaps.insert(std::make_pair("shader_MetalConvolution_metal", shader_MetalConvolution_metal));
|
||||
|
|
|
|||
|
|
@ -1,38 +1,10 @@
|
|||
struct tensor_shape {
|
||||
int size;
|
||||
int channel;
|
||||
int slice;
|
||||
int batch;
|
||||
int batch_slices;
|
||||
};
|
||||
|
||||
kernel void version_func_002(const device uchar *in [[buffer(0)]],
|
||||
device uchar *out [[buffer(1)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
// do nothing, just for verifying the match between MNN and the metallib
|
||||
}
|
||||
|
||||
kernel void copy_byte(const device uchar *in [[buffer(0)]],
|
||||
device uchar *out [[buffer(1)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = in[int(gid)];
|
||||
}
|
||||
|
||||
kernel void copy_float(const device ftype *in [[buffer(0)]],
|
||||
device ftype *out [[buffer(1)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = in[int(gid)];
|
||||
}
|
||||
|
||||
kernel void upcast_float(const device ftype *in [[buffer(0)]],
|
||||
device float *out [[buffer(1)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = in[int(gid)];
|
||||
}
|
||||
kernel void downcast_float(const device float *in [[buffer(0)]],
|
||||
device ftype *out [[buffer(1)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = in[int(gid)];
|
||||
}
|
||||
struct Limit {
|
||||
uint4 size;
|
||||
};
|
||||
|
|
@ -55,8 +27,8 @@ kernel void downcast_float4(const device float4 *in [[buffer(0)]],
|
|||
|
||||
template <typename IType, typename OType>
|
||||
static inline void template_NHWC_to_NC4HW4(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int z = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int z = gid.y / s.batch;
|
||||
int c = z * 4;
|
||||
|
||||
auto off_in = in + b * s.size * s.channel + int(gid.x) * s.channel + c;
|
||||
|
|
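The conversion templates change how gid.y is split into (batch, slice): the old code was batch-major (b = gid.y / slice, z = gid.y % slice), the new code is batch-fastest (b = gid.y % batch, z = gid.y / batch), matching the batch-interleaved NC4HW4 storage adopted across the Metal kernels in this commit. Both decompositions side by side, as a plain C++ reference (names are illustrative):

struct Grid { int batch; int slice; };

// Old mapping: consecutive gid.y values walk the slices of one image first.
inline void decodeBatchMajor(int y, const Grid& g, int& b, int& z) {
    b = y / g.slice;
    z = y % g.slice;
}

// New mapping: consecutive gid.y values walk the batch first, then move to the next slice.
inline void decodeBatchFastest(int y, const Grid& g, int& b, int& z) {
    b = y % g.batch;
    z = y / g.batch;
}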
@ -93,8 +65,8 @@ kernel void cvt_f_NHWC_to_NC4HW4(const device ftype *in [[buffer(0)]],
|
|||
|
||||
template <typename IType, typename OType>
|
||||
static inline void template_NC4HW4_to_NHWC(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int z = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int z = gid.y / s.batch;
|
||||
int c = z * 4;
|
||||
auto off_in = in + int(gid.y) * s.size + int(gid.x);
|
||||
auto off_out = out + b * s.size * s.channel + int(gid.x) * s.channel + c;
|
||||
|
|
@ -132,8 +104,8 @@ kernel void cvt_f_NC4HW4_to_NHWC(const device ftype4 *in [[buffer(0)]],
|
|||
|
||||
template <typename IType, typename OType>
|
||||
static inline void template_NCHW_to_NC4HW4(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int z = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int z = gid.y / s.batch;
|
||||
int c = z * 4;
|
||||
|
||||
auto off_in = in + (b * s.channel + c) * s.size + int(gid.x);
|
||||
|
|
@ -170,8 +142,8 @@ kernel void cvt_f_NCHW_to_NC4HW4(const device ftype *in [[buffer(0)]],
|
|||
|
||||
template <typename IType, typename OType>
|
||||
static inline void template_NC4HW4_to_NCHW(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int z = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int z = gid.y / s.batch;
|
||||
int c = z * 4;
|
||||
|
||||
auto off_in = in + int(gid.y) * s.size + int(gid.x);
|
||||
|
|
@ -210,8 +182,8 @@ kernel void cvt_f_NC4HW4_to_NCHW(const device ftype4 *in [[buffer(0)]],
|
|||
template<typename IType, typename OType>
|
||||
static inline void template_NHWC_to_NCHW(const device IType* in,
|
||||
device OType* out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int c4 = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int c4 = gid.y / s.batch;
|
||||
|
||||
auto in_off = (b * s.size + gid.x) * s.channel + c4 * 4;
|
||||
auto out_off = (b * s.channel + c4 * 4) * s.size + gid.x;
|
||||
|
|
@ -238,8 +210,8 @@ kernel void upcast_f_NHWC_to_NCHW(const device ftype *in [[buffer(0)]],
|
|||
template<typename IType, typename OType>
|
||||
static inline void template_NCHW_to_NHWC(const device IType* in,
|
||||
device OType* out, constant tensor_shape &s, uint2 gid) {
|
||||
int b = gid.y / s.slice;
|
||||
int c4 = gid.y % s.slice;
|
||||
int b = gid.y % s.batch;
|
||||
int c4 = gid.y / s.batch;
|
||||
|
||||
auto in_off = (b * s.channel + c4 * 4) * s.size + gid.x;
|
||||
auto out_off = (b * s.size + gid.x) * s.channel + c4 * 4;
|
||||
|
|
|
|||
|
|
@ -102,9 +102,9 @@ kernel void conv(const device ftype4 *in [[buffer(0)]],
|
|||
offset_x += sx * cst.dilation_x;
|
||||
offset_y += sy * cst.dilation_y;
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + (int)idx_c * cst.output_size + (int)gid.y * cst.output_width + (int)gid.x;
|
||||
auto z_out = out + idx_b * cst.output_size + (int)idx_c * cst.batch * cst.output_size + (int)gid.y * cst.output_width + (int)gid.x;
|
||||
|
||||
int dilation_h = cst.input_width * cst.dilation_y;
|
||||
FLOAT4 result = FLOAT4(biasTerms[idx_c]);
|
||||
|
|
@ -112,7 +112,7 @@ kernel void conv(const device ftype4 *in [[buffer(0)]],
|
|||
for (auto y = 0; y < kh; y++) {
|
||||
for (auto x = 0; x < kw; x++) {
|
||||
auto wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x + x];
|
||||
auto in4 = z_in[z * cst.input_size + y * dilation_h + x * cst.dilation_x];
|
||||
auto in4 = z_in[z * cst.input_size * cst.batch + y * dilation_h + x * cst.dilation_x];
|
||||
result += FLOAT4(in4 * wt4);
|
||||
}
|
||||
}
|
||||
|
|
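With the interleaved layout the convolution kernels start the input pointer at idx_b * input_size and advance by input_size * batch per input slice, and the output channel stride likewise becomes batch * output_size. A hedged helper that reproduces the resulting element offset for (batch, slice, y, x):

struct ConvShape { int batch; int input_size; int input_width; };

// Linear ftype4 offset of element (b, z, y, x) in the batch-interleaved buffer,
// equivalent to: in + b * input_size, stepping by input_size * batch per slice.
inline int interleavedOffset(const ConvShape& s, int b, int z, int y, int x) {
    return (z * s.batch + b) * s.input_size + y * s.input_width + x;
}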
@ -141,15 +141,15 @@ kernel void convk3s1d1p1_w2z4(const device ftype4 *in [[buffer(0)]],
|
|||
int offset_x = (int)gid.x * 2 - cst.pad_x;
|
||||
int offset_y = (int)gid.y - cst.pad_y;
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_flt = wt + uz[0] * cst.input_slice * cst.kernel_size;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
int ws = cst.input_slice * cst.kernel_size;
|
||||
FLOAT4 result0 = 0, result1 = 0, result2 = 0, result3 = 0;
|
||||
FLOAT4 result4 = 0, result5 = 0, result6 = 0, result7 = 0;
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_flt += cst.kernel_size, z_in += cst.input_size) {
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_flt += cst.kernel_size, z_in += (cst.input_size * cst.batch)) {
|
||||
auto in00 = (offset_x<0 || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+0);
|
||||
auto in01 = (offset_x+1>=cst.input_width || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+1);
|
||||
auto in02 = (offset_x+2>=cst.input_width || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+2);
|
||||
|
|
@ -228,19 +228,19 @@ kernel void conv_s1d1p0_w2(const device ftype4 *in [[buffer(0)]],
|
|||
|
||||
bool valid = (idx_w + 1 < cst.output_width);
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + idx_h * cst.input_width + idx_w;
|
||||
auto z_in = in + idx_b * cst.input_size + idx_h * cst.input_width + idx_w;
|
||||
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto z_out = out + idx_b * cst.output_size + idx_c * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
FLOAT4 result0 = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result1 = result0;
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
for (auto y = 0; y < cst.kernel_y; y++) {
|
||||
auto wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x];
|
||||
auto in4_0 = z_in[z * cst.input_size + y * cst.input_width];
|
||||
auto in4_0 = z_in[z * cst.batch * cst.input_size + y * cst.input_width];
|
||||
result0 += FLOAT4(in4_0 * wt4);
|
||||
for (auto x = 1; x < cst.kernel_x; x++) {
|
||||
in4_0 = z_in[z * cst.input_size + y * cst.input_width + x];
|
||||
in4_0 = z_in[z * cst.batch * cst.input_size + y * cst.input_width + x];
|
||||
result1 += FLOAT4(in4_0 * wt4);
|
||||
wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x + x];
|
||||
result0 += FLOAT4(in4_0 * wt4);
|
||||
|
|
@ -271,9 +271,9 @@ kernel void conv_s1d1p0_w4(const device ftype4 *in [[buffer(0)]],
|
|||
int3 uz = idx_w + int3(1, 2, 3);
|
||||
bool3 valids = uz.xyz < cst.output_width;
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + idx_h * cst.input_width + idx_w;
|
||||
auto z_in = in + idx_b * cst.input_size + idx_h * cst.input_width + idx_w;
|
||||
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto z_out = out + idx_b * cst.output_size + idx_c * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
FLOAT4 result0 = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result1 = result0;
|
||||
|
|
@ -286,7 +286,7 @@ kernel void conv_s1d1p0_w4(const device ftype4 *in [[buffer(0)]],
|
|||
auto wt4_1 = wt_base[1];
|
||||
auto wt4_2 = wt_base[2];
|
||||
|
||||
auto z_in_base = z_in + z * cst.input_size + y * cst.input_width;
|
||||
auto z_in_base = z_in + z * cst.batch * cst.input_size + y * cst.input_width;
|
||||
auto in4_0 = z_in_base[0];
|
||||
result0 += FLOAT4(in4_0 * wt4_0);
|
||||
|
||||
|
|
@ -346,14 +346,14 @@ kernel void conv_z4(const device ftype4 *in [[buffer(0)]],
|
|||
offset_x += sx * cst.dilation_x;
|
||||
offset_y += sy * cst.dilation_y;
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_wt = wt + uz[0] * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
int ws = cst.input_slice * cst.kernel_size;
|
||||
int dilation_h = cst.input_width * cst.dilation_y;
|
||||
FLOAT4 result0 = 0, result1 = 0, result2 = 0, result3 = 0;
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size) {
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size * cst.batch) {
|
||||
for (auto y = 0; y < kh; y++) {
|
||||
for (auto x = 0; x < kw; x++) {
|
||||
auto x_wt = z_wt + y * cst.kernel_x + x;
|
||||
|
|
@ -400,14 +400,14 @@ kernel void conv_z2(const device ftype4 *in [[buffer(0)]],
|
|||
offset_x += sx * cst.dilation_x;
|
||||
offset_y += sy * cst.dilation_y;
|
||||
|
||||
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
|
||||
auto z_wt = wt + uz[0] * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
|
||||
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
int ws = cst.input_slice * cst.kernel_size;
|
||||
int dilation_h = cst.input_width * cst.dilation_y;
|
||||
FLOAT4 result0 = 0, result1 = 0;
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size) {
|
||||
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size * cst.batch) {
|
||||
for (auto y = 0; y < kh; y++) {
|
||||
for (auto x = 0; x < kw; x++) {
|
||||
auto x_wt = z_wt + y * cst.kernel_x + x;
|
||||
|
|
@ -418,5 +418,5 @@ kernel void conv_z2(const device ftype4 *in [[buffer(0)]],
|
|||
}
|
||||
}
|
||||
/* true */ *z_out = activate(ftype4(result0 + FLOAT4(biasTerms[uz[0]])), cst.activation);
|
||||
if (valids) { z_out += cst.output_size; *z_out = activate(ftype4(result1 + FLOAT4(biasTerms[uz[1]])), cst.activation); }
|
||||
if (valids) { z_out += cst.output_size * cst.batch; *z_out = activate(ftype4(result1 + FLOAT4(biasTerms[uz[1]])), cst.activation); }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,36 +13,6 @@ struct conv1x1_constants {
|
|||
conv_activation_type activation;
|
||||
};
|
||||
|
||||
kernel void conv1x1_w1h1(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
int idx_w = gid.x;
|
||||
int idx_h = gid.y;
|
||||
int idx_c = gid.z % cst.output_slice;
|
||||
int idx_b = gid.z / cst.output_slice;
|
||||
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result0 = biasValue;
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = xy_in0[0];
|
||||
auto w = xy_wt[z];
|
||||
|
||||
result0 += FLOAT4(in40 * w);
|
||||
xy_in0 += cst.input_size;
|
||||
}
|
||||
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
}
|
||||
|
||||
kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
|
|
@ -54,8 +24,8 @@ kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
|
|||
int rx = gid.x * CONV_UNROLL;
|
||||
int uz = gid.y;
|
||||
auto xy_wt = wt + uz * cst.input_slice;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
|
||||
auto biasValue = FLOAT4(biasTerms[uz]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
|
||||
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
|
||||
|
|
@ -71,7 +41,7 @@ kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
|
|||
result1 += FLOAT4(in41 * w);
|
||||
result2 += FLOAT4(in42 * w);
|
||||
result3 += FLOAT4(in43 * w);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
|
|
@ -86,20 +56,19 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]],
|
|||
const device MNN::char4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
const device float4 *dequantScale [[buffer(5)]],
|
||||
const device float4 *dequantBias [[buffer(6)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
|
||||
|
||||
int rx = gid.x * CONV_UNROLL;
|
||||
int uz = gid.y;
|
||||
auto xy_wt = wt + uz * cst.input_slice;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
|
||||
auto biasValue = FLOAT4(biasTerms[uz]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
|
||||
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
|
||||
auto scale = FLOAT4(dequantScale[uz]);
|
||||
auto dequant_bias = FLOAT4(dequantBias[uz]);
|
||||
auto dequant_bias = FLOAT4(dequantScale[uz + cst.output_slice]);
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = (FLOAT4)*xy_in0;
|
||||
|
|
@ -112,20 +81,14 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]],
|
|||
FLOAT4x4 w_fp32 = FLOAT4x4(FLOAT4(w[0]), FLOAT4(w[1]), FLOAT4(w[2]), FLOAT4(w[3]));
|
||||
FLOAT4x4 w_dequant;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
FLOAT4 w4 = w_fp32[i];
|
||||
FLOAT4 res;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float wf = w4[j] * scale[i] + dequant_bias[i];
|
||||
res[j] = wf;
|
||||
}
|
||||
w_dequant[i] = res;
|
||||
w_dequant[i] = w_fp32[i] * scale[i] + dequant_bias[i];
|
||||
}
|
||||
|
||||
result0 += FLOAT4(in40 * w_dequant);
|
||||
result1 += FLOAT4(in41 * w_dequant);
|
||||
result2 += FLOAT4(in42 * w_dequant);
|
||||
result3 += FLOAT4(in43 * w_dequant);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
/* true */
|
||||
|
|
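In the quantized conv1x1 kernels the dequant bias is now read from the same buffer as the scale, at an offset of output_slice entries (dequantScale[uz + cst.output_slice]), so scales and biases are expected to be packed back to back on the host. A small sketch of that packing; the function name and buffer arrangement are assumptions:

#include <vector>

// Pack per-output-slice scale float4s followed by bias float4s into one buffer,
// so a kernel can read the bias for slice uz at index (uz + output_slice).
std::vector<float> packDequantParams(const std::vector<float>& scales,   // output_slice * 4 floats
                                     const std::vector<float>& biases) { // output_slice * 4 floats
    std::vector<float> packed;
    packed.reserve(scales.size() + biases.size());
    packed.insert(packed.end(), scales.begin(), scales.end());
    packed.insert(packed.end(), biases.begin(), biases.end());
    return packed;
}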
@ -141,20 +104,19 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
|
|||
const device MNN::uchar4x2 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
const device float4 *dequantScale [[buffer(5)]],
|
||||
const device float4 *dequantBias [[buffer(6)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
|
||||
|
||||
int rx = gid.x * CONV_UNROLL;
|
||||
int uz = gid.y;
|
||||
auto xy_wt = wt + uz * cst.input_slice;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
|
||||
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
|
||||
auto biasValue = FLOAT4(biasTerms[uz]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
|
||||
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
|
||||
auto scale = FLOAT4(dequantScale[uz]);
|
||||
auto dequant_bias = FLOAT4(dequantBias[uz]);
|
||||
auto dequant_bias = FLOAT4(dequantScale[uz + cst.output_slice]);
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = (FLOAT4)*xy_in0;
|
||||
|
|
@ -169,11 +131,7 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
|
|||
for (int i = 0; i < 4; ++i) {
|
||||
// ftype4 w4 = ftype4(w_fp32[i]);
|
||||
FLOAT4 w4 = FLOAT4((float)(w_int4[i][0] >> 4) - 8, (float)(w_int4[i][0] & 15) - 8, (float)(w_int4[i][1] >> 4) - 8, (float)(w_int4[i][1] & 15) - 8);
|
||||
FLOAT4 res;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float wf = w4[j] * scale[i] + dequant_bias[i];
|
||||
res[j] = wf;
|
||||
}
|
||||
FLOAT4 res = w4 * scale[i] + dequant_bias[i];
|
||||
w_dequant[i] = res;
|
||||
}
|
||||
|
||||
|
|
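For the 4-bit weight path each byte stores two weights: the kernel splits the high and low nibbles, recenters them by subtracting 8, and applies the affine dequant w * scale + bias in one vector expression instead of the previous per-component loop. A scalar C++ reference of the same unpacking:

#include <cstdint>

// Unpack one packed byte into two signed 4-bit weights and dequantize them,
// mirroring (w >> 4) - 8 and (w & 15) - 8 followed by w * scale + bias.
inline void dequantInt4Pair(uint8_t packed, float scale, float bias, float out[2]) {
    int hi = (int)(packed >> 4) - 8;   // high nibble, range [-8, 7]
    int lo = (int)(packed & 15) - 8;   // low nibble,  range [-8, 7]
    out[0] = hi * scale + bias;
    out[1] = lo * scale + bias;
}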
@ -181,7 +139,7 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
|
|||
result1 += FLOAT4(in41 * w_dequant);
|
||||
result2 += FLOAT4(in42 * w_dequant);
|
||||
result3 += FLOAT4(in43 * w_dequant);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
/* true */
|
||||
|
|
@ -213,9 +171,9 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
|
|||
int rx = gid.x * CONV_UNROLL_L;
|
||||
int uz = gid.y;
|
||||
auto xy_wt = wt + uz * cst.input_slice;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
|
||||
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
|
||||
|
||||
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
|
||||
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.batch * cst.output_size + rx;
|
||||
auto biasValue = FLOAT4(biasTerms[uz]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
|
||||
FLOAT4 result4 = biasValue, result5 = biasValue, result6 = biasValue, result7 = biasValue;
|
||||
|
|
@ -241,7 +199,7 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
|
|||
result5 += FLOAT4(in45 * w);
|
||||
result6 += FLOAT4(in46 * w);
|
||||
result7 += FLOAT4(in47 * w);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
|
|
@ -254,84 +212,23 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
|
|||
if (computeSize > 7) {xy_out[7] = activate(ftype4(result7), cst.activation); }
|
||||
}
|
||||
|
||||
|
||||
kernel void conv1x1_w4h2(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int idx_w = gid.x << 2;
|
||||
int idx_h = gid.y << 1;
|
||||
int idx_c = gid.z % cst.output_slice;
|
||||
int idx_b = gid.z / cst.output_slice;
|
||||
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
|
||||
FLOAT4 result4 = biasValue, result5 = biasValue, result6 = biasValue, result7 = biasValue;
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = xy_in0[0];
|
||||
auto in41 = xy_in0[1];
|
||||
auto in42 = xy_in0[2];
|
||||
auto in43 = xy_in0[3];
|
||||
auto in44 = xy_in0[cst.output_width+0];
|
||||
auto in45 = xy_in0[cst.output_width+1];
|
||||
auto in46 = xy_in0[cst.output_width+2];
|
||||
auto in47 = xy_in0[cst.output_width+3];
|
||||
|
||||
auto w = xy_wt[z];
|
||||
|
||||
result0 += FLOAT4(in40 * w);
|
||||
result1 += FLOAT4(in41 * w);
|
||||
result2 += FLOAT4(in42 * w);
|
||||
result3 += FLOAT4(in43 * w);
|
||||
result4 += FLOAT4(in44 * w);
|
||||
result5 += FLOAT4(in45 * w);
|
||||
result6 += FLOAT4(in46 * w);
|
||||
result7 += FLOAT4(in47 * w);
|
||||
xy_in0 += cst.input_size;
|
||||
}
|
||||
|
||||
int widthSize = min(cst.output_width - idx_w, 4);
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[2] = activate(ftype4(result2), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[3] = activate(ftype4(result3), cst.activation); }
|
||||
|
||||
int heightSize = min(cst.output_height - idx_h, 2);
|
||||
if(heightSize > 1) {
|
||||
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[cst.output_width+2] = activate(ftype4(result6), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[cst.output_width+3] = activate(ftype4(result7), cst.activation); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 4 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
if ((int)gid.x * 16 >= cst.output_width || (int)gid.y >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int idx_w = gid.x << 2;
|
||||
int idx_h = gid.y << 2;
|
||||
int idx_c = gid.z % cst.output_slice;
|
||||
int idx_b = gid.z / cst.output_slice;
|
||||
int idx_w = gid.x << 4;
|
||||
int idx_h = 0;
|
||||
int idx_c = gid.y / cst.batch;
|
||||
int idx_b = gid.y % cst.batch;
|
||||
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result00 = biasValue, result01 = biasValue, result02 = biasValue, result03 = biasValue;
|
||||
FLOAT4 result10 = biasValue, result11 = biasValue, result12 = biasValue, result13 = biasValue;
|
||||
|
|
@ -343,19 +240,19 @@ kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
|
|||
auto in01 = xy_in0[1];
|
||||
auto in02 = xy_in0[2];
|
||||
auto in03 = xy_in0[3];
|
||||
auto in10 = xy_in0[cst.output_width+0];
|
||||
auto in11 = xy_in0[cst.output_width+1];
|
||||
auto in12 = xy_in0[cst.output_width+2];
|
||||
auto in13 = xy_in0[cst.output_width+3];
|
||||
auto in10 = xy_in0[4];
|
||||
auto in11 = xy_in0[5];
|
||||
auto in12 = xy_in0[6];
|
||||
auto in13 = xy_in0[7];
|
||||
|
||||
auto in20 = xy_in0[cst.output_width+cst.output_width+0];
|
||||
auto in21 = xy_in0[cst.output_width+cst.output_width+1];
|
||||
auto in22 = xy_in0[cst.output_width+cst.output_width+2];
|
||||
auto in23 = xy_in0[cst.output_width+cst.output_width+3];
|
||||
auto in30 = xy_in0[cst.output_width+cst.output_width+cst.output_width+0];
|
||||
auto in31 = xy_in0[cst.output_width+cst.output_width+cst.output_width+1];
|
||||
auto in32 = xy_in0[cst.output_width+cst.output_width+cst.output_width+2];
|
||||
auto in33 = xy_in0[cst.output_width+cst.output_width+cst.output_width+3];
|
||||
auto in20 = xy_in0[8];
|
||||
auto in21 = xy_in0[9];
|
||||
auto in22 = xy_in0[10];
|
||||
auto in23 = xy_in0[11];
|
||||
auto in30 = xy_in0[12];
|
||||
auto in31 = xy_in0[13];
|
||||
auto in32 = xy_in0[14];
|
||||
auto in33 = xy_in0[15];
|
||||
|
||||
|
||||
auto w = xy_wt[z];
|
||||
|
|
@ -378,34 +275,26 @@ kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
|
|||
result32 += FLOAT4(in32 * w);
|
||||
result33 += FLOAT4(in33 * w);
|
||||
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
int widthSize = min(cst.output_width - idx_w, 4);
|
||||
int widthSize = min(cst.output_width - idx_w, 16);
|
||||
/* true */ *xy_out = activate(ftype4(result00), cst.activation);
|
||||
if (widthSize > 1) {xy_out[1] = activate(ftype4(result01), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[2] = activate(ftype4(result02), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[3] = activate(ftype4(result03), cst.activation); }
|
||||
|
||||
int heightSize = min(cst.output_height - idx_h, 4);
|
||||
if(heightSize > 1) {
|
||||
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result10), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result11), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[cst.output_width+2] = activate(ftype4(result12), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[cst.output_width+3] = activate(ftype4(result13), cst.activation); }
|
||||
}
|
||||
if(heightSize > 2) {
|
||||
/* true */ {xy_out[cst.output_width+cst.output_width+0] = activate(ftype4(result20), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+cst.output_width+1] = activate(ftype4(result21), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[cst.output_width+cst.output_width+2] = activate(ftype4(result22), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[cst.output_width+cst.output_width+3] = activate(ftype4(result23), cst.activation); }
|
||||
}
|
||||
if(heightSize > 3) {
|
||||
/* true */ {xy_out[cst.output_width+cst.output_width+cst.output_width+0] = activate(ftype4(result30), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+cst.output_width+cst.output_width+1] = activate(ftype4(result31), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[cst.output_width+cst.output_width+cst.output_width+2] = activate(ftype4(result32), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[cst.output_width+cst.output_width+cst.output_width+3] = activate(ftype4(result33), cst.activation); }
|
||||
}
|
||||
if (widthSize > 4) {xy_out[4] = activate(ftype4(result10), cst.activation); }
|
||||
if (widthSize > 5) {xy_out[5] = activate(ftype4(result11), cst.activation); }
|
||||
if (widthSize > 6) {xy_out[6] = activate(ftype4(result12), cst.activation); }
|
||||
if (widthSize > 7) {xy_out[7] = activate(ftype4(result13), cst.activation); }
|
||||
if (widthSize > 8) {xy_out[8] = activate(ftype4(result20), cst.activation); }
|
||||
if (widthSize > 9) {xy_out[9] = activate(ftype4(result21), cst.activation); }
|
||||
if (widthSize > 10) {xy_out[10] = activate(ftype4(result22), cst.activation); }
|
||||
if (widthSize > 11) {xy_out[11] = activate(ftype4(result23), cst.activation); }
|
||||
if (widthSize > 12) {xy_out[12] = activate(ftype4(result30), cst.activation); }
|
||||
if (widthSize > 13) {xy_out[13] = activate(ftype4(result31), cst.activation); }
|
||||
if (widthSize > 14) {xy_out[14] = activate(ftype4(result32), cst.activation); }
|
||||
if (widthSize > 15) {xy_out[15] = activate(ftype4(result33), cst.activation); }
|
||||
}
|
||||
|
||||
|
||||
|
|
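conv1x1_w4h4 changes from a 4x4 spatial tile to 16 consecutive width positions per thread (idx_w = gid.x << 4), and a single widthSize = min(output_width - idx_w, 16) guard decides how many of the 16 accumulated results are stored. The same tail-guarded unroll pattern in plain C++, as a sketch:

#include <algorithm>

// Store up to 16 consecutive outputs; only the first `widthSize` results are
// written so the final tile of a row never writes out of bounds.
void storeTile16(float* rowOut, const float* results, int output_width, int idx_w) {
    int widthSize = std::min(output_width - idx_w, 16);
    for (int i = 0; i < widthSize; ++i) {
        rowOut[idx_w + i] = results[i];
    }
}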
@ -415,19 +304,19 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
|
|||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z * 2 >= cst.batch * cst.output_slice) return;
|
||||
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int channel_pack = (cst.output_channel + 7) >> 3;
|
||||
int idx_w = gid.x << 1;
|
||||
int idx_h = gid.y;
|
||||
int idx_c = (gid.z % channel_pack) << 1;
|
||||
int idx_b = gid.z / channel_pack;
|
||||
int idx_h = 0;
|
||||
int idx_c = (gid.y % channel_pack) << 1;
|
||||
int idx_b = gid.y / channel_pack;
|
||||
|
||||
if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue0 = FLOAT4(biasTerms[idx_c]);
|
||||
auto biasValue1 = FLOAT4(biasTerms[idx_c+1]);
|
||||
|
||||
|
|
@ -445,7 +334,7 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
|
|||
result1 += FLOAT4(in41 * w0);
|
||||
result4 += FLOAT4(in40 * w1);
|
||||
result5 += FLOAT4(in41 * w1);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
int widthSize = min(cst.output_width - idx_w, 2);
|
||||
|
|
@ -454,79 +343,30 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
|
|||
|
||||
int channelSize = min(cst.output_slice - idx_c, 2);
|
||||
if(channelSize > 1) {
|
||||
/* true */ {xy_out[cst.output_size+0] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_size+1] = activate(ftype4(result5), cst.activation); }
|
||||
/* true */ {xy_out[cst.output_size * cst.batch +0] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_size * cst.batch +1] = activate(ftype4(result5), cst.activation); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
kernel void conv1x1_w2h2(const device ftype4 *in [[buffer(0)]],
|
||||
kernel void conv1x1_w4c2(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int idx_w = gid.x << 1;
|
||||
int idx_h = gid.y << 1;
|
||||
int idx_c = gid.z % cst.output_slice;
|
||||
int idx_b = gid.z / cst.output_slice;
|
||||
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue = FLOAT4(biasTerms[idx_c]);
|
||||
FLOAT4 result0 = biasValue, result1 = biasValue;
|
||||
FLOAT4 result4 = biasValue, result5 = biasValue;
|
||||
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = xy_in0[0];
|
||||
auto in41 = xy_in0[1];
|
||||
auto in44 = xy_in0[cst.output_width+0];
|
||||
auto in45 = xy_in0[cst.output_width+1];
|
||||
|
||||
auto w = xy_wt[z];
|
||||
|
||||
result0 += FLOAT4(in40 * w);
|
||||
result1 += FLOAT4(in41 * w);
|
||||
result4 += FLOAT4(in44 * w);
|
||||
result5 += FLOAT4(in45 * w);
|
||||
xy_in0 += cst.input_size;
|
||||
}
|
||||
|
||||
int widthSize = min(cst.output_width - idx_w, 2);
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
|
||||
|
||||
int heightSize = min(cst.output_height - idx_h, 2);
|
||||
if(heightSize > 1) {
|
||||
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant conv1x1_constants& cst [[buffer(2)]],
|
||||
const device ftype4x4 *wt [[buffer(3)]],
|
||||
const device ftype4 *biasTerms [[buffer(4)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z * 2 >= cst.batch * cst.output_slice) return;
|
||||
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 2 >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int channel_pack = (cst.output_channel + 7) >> 3;
|
||||
int idx_w = gid.x << 1;
|
||||
int idx_h = gid.y << 1;
|
||||
int idx_c = (gid.z % channel_pack) << 1;
|
||||
int idx_b = gid.z / channel_pack;
|
||||
int idx_w = gid.x << 2;
|
||||
int idx_h = 0;
|
||||
int idx_c = (gid.y % channel_pack) << 1;
|
||||
int idx_b = gid.y / channel_pack;
|
||||
|
||||
if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;
|
||||
auto xy_wt = wt + idx_c * cst.input_slice;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
|
||||
|
||||
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
|
||||
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
|
||||
auto biasValue0 = FLOAT4(biasTerms[idx_c]);
|
||||
auto biasValue1 = FLOAT4(biasTerms[idx_c+1]);
|
||||
|
||||
|
|
@ -537,8 +377,8 @@ kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
|
|||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
auto in40 = xy_in0[0];
|
||||
auto in41 = xy_in0[1];
|
||||
auto in44 = xy_in0[cst.output_width+0];
|
||||
auto in45 = xy_in0[cst.output_width+1];
|
||||
auto in44 = xy_in0[2];
|
||||
auto in45 = xy_in0[3];
|
||||
|
||||
auto w0 = xy_wt[z];
|
||||
auto w1 = xy_wt[cst.input_slice+z];
|
||||
|
|
@ -551,27 +391,20 @@ kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
|
|||
result3 += FLOAT4(in41 * w1);
|
||||
result6 += FLOAT4(in44 * w1);
|
||||
result7 += FLOAT4(in45 * w1);
|
||||
xy_in0 += cst.input_size;
|
||||
xy_in0 += cst.input_size * cst.batch;
|
||||
}
|
||||
|
||||
int widthSize = min(cst.output_width - idx_w, 2);
|
||||
int widthSize = min(cst.output_width - idx_w, 4);
|
||||
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
|
||||
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
|
||||
|
||||
int heightSize = min(cst.output_height - idx_h, 2);
|
||||
if(heightSize > 1) {
|
||||
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
|
||||
}
|
||||
|
||||
if (widthSize > 2) {xy_out[2] = activate(ftype4(result4), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[3] = activate(ftype4(result5), cst.activation); }
|
||||
|
||||
int channelSize = min(cst.output_slice - idx_c, 2);
|
||||
if(channelSize > 1) {
|
||||
/* true */ xy_out[cst.output_size] = activate(ftype4(result2), cst.activation);
|
||||
if (widthSize > 1) {xy_out[cst.output_size+1] = activate(ftype4(result3), cst.activation); }
|
||||
|
||||
if(heightSize > 1) {
|
||||
/* true */ {xy_out[cst.output_size+cst.output_width+0] = activate(ftype4(result6), cst.activation); }
|
||||
if (widthSize > 1) {xy_out[cst.output_size+cst.output_width+1] = activate(ftype4(result7), cst.activation); }
|
||||
}
|
||||
/* true */ xy_out[cst.output_size * cst.batch] = activate(ftype4(result2), cst.activation);
|
||||
if (widthSize > 1) {xy_out[cst.output_size * cst.batch +1] = activate(ftype4(result3), cst.activation); }
|
||||
if (widthSize > 2) {xy_out[cst.output_size * cst.batch +2] = activate(ftype4(result6), cst.activation); }
|
||||
if (widthSize > 3) {xy_out[cst.output_size * cst.batch +3] = activate(ftype4(result7), cst.activation); }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ kernel void conv_depthwise(const device ftype4 *in [[buffer(0)]],
|
|||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice * cst.batch) return;
|
||||
|
||||
int oz = gid.z % cst.slice;
|
||||
int oz = gid.z / cst.batch;
|
||||
int offset_x = (int)gid.x * cst.stride_x - cst.pad_x;
|
||||
int offset_y = (int)gid.y * cst.stride_y - cst.pad_y;
|
||||
int sx = max(0, (UP_DIV(-offset_x, cst.dilation_x)));
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ kernel void deconv(const device ftype4 *in [[buffer(0)]],
|
|||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
|
||||
int b = gid.z / cst.output_slice;
|
||||
int o = gid.z % cst.output_slice;
|
||||
int b = gid.z % cst.batch;
|
||||
int o = gid.z / cst.batch;
|
||||
FLOAT4 result = FLOAT4(biasTerms[o]);
|
||||
|
||||
int oy = (int)gid.y + cst.pad_y;
|
||||
|
|
@ -56,12 +56,12 @@ kernel void deconv(const device ftype4 *in [[buffer(0)]],
|
|||
int min_ix = (ox - max_kx * cst.dilation_x) / cst.stride_x;
|
||||
|
||||
auto o_wt = wt + o * cst.input_slice * cst.kernel_size;
|
||||
auto b_in = in + b * cst.input_slice * cst.input_size;
|
||||
auto b_in = in + b * cst.input_size;
|
||||
for (auto z = 0; z < cst.input_slice; z++) {
|
||||
for (auto ky = max_ky, iy = min_iy; ky >= min_ky; ky -= cst.delta_ky, iy += cst.delta_iy) {
|
||||
for (auto kx = max_kx, ix = min_ix; kx >= min_kx; kx -= cst.delta_kx, ix += cst.delta_ix) {
|
||||
auto wt4 = o_wt[z * cst.kernel_size + ky * cst.kernel_x + kx];
|
||||
auto in4 = b_in[z * cst.input_size + iy * cst.input_width + ix];
|
||||
auto in4 = b_in[z * cst.input_size * cst.batch + iy * cst.input_width + ix];
|
||||
result += FLOAT4(in4 * wt4);
|
||||
}
|
||||
}
|
||||
|
|
@ -78,7 +78,7 @@ kernel void deconv_depthwise(const device ftype4 *in [[buffer(0)]],
|
|||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
|
||||
|
||||
FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z % cst.input_slice)]);
|
||||
FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z / cst.batch)]);
|
||||
|
||||
int oy = (int)gid.y + cst.pad_y;
|
||||
int ox = (int)gid.x + cst.pad_x;
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ kernel void prelu_slopes(const device ftype4 *in [[buffer(0)]],
|
|||
uint3 gid [[thread_position_in_grid]]) { // size, slice, batch
|
||||
if ((int)gid.x >= s.size || (int)gid.y >= s.slice) return;
|
||||
|
||||
int z = gid.z * s.slice + gid.y;
|
||||
int z = gid.z + gid.y * s.batch;
|
||||
auto v4 = in[z * s.size + int(gid.x)];
|
||||
out[z * s.size + int(gid.x)] = select(v4, ftype4(slope[int(gid.y)]) * v4, signbit(v4));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ struct ROI_shape {
|
|||
int input_width;
|
||||
int input_height;
|
||||
int input_size;
|
||||
int input_batch;
|
||||
int output_width;
|
||||
int output_height;
|
||||
int output_size;
|
||||
int slices;
|
||||
int batch;
|
||||
float spatial_scale;
|
||||
};
|
||||
kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
|
||||
|
|
@ -15,10 +16,10 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
|
|||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;
|
||||
|
||||
int ob = gid.z / s.slices;
|
||||
int iz = gid.z % s.slices;
|
||||
int ob = gid.z % s.batch;
|
||||
int iz = gid.z / s.batch;
|
||||
|
||||
auto b_roi = roi + ob * 8; // roundup(5, 4) = 8
|
||||
auto b_roi = roi + ob * 5;
|
||||
int ib = int(b_roi[0]);
|
||||
int x1 = round(float(b_roi[1]) * s.spatial_scale);
|
||||
int y1 = round(float(b_roi[2]) * s.spatial_scale);
|
||||
|
|
@ -27,8 +28,8 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
|
|||
|
||||
int roi_w = max(x2 - x1 + 1, 1);
|
||||
int roi_h = max(y2 - y1 + 1, 1);
|
||||
auto bin_size_w = (ftype)roi_w / s.output_width;
|
||||
auto bin_size_h = (ftype)roi_h / s.output_height;
|
||||
float bin_size_w = (float)roi_w / (float)s.output_width;
|
||||
float bin_size_h = (float)roi_h / (float)s.output_height;
|
||||
|
||||
int w_start = clamp(x1 + (int)floor(gid.x * bin_size_w) , 0, s.input_width);
|
||||
int w_end = clamp(x1 + (int)ceil((gid.x + 1) * bin_size_w), 0, s.input_width);
|
||||
|
|
@ -36,7 +37,7 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
|
|||
int h_end = clamp(y1 + (int)ceil((gid.y + 1) * bin_size_h), 0, s.input_height);
|
||||
|
||||
int is_empty = (h_end <= h_start) || (w_end <= w_start);
|
||||
auto z_in = in + (ib * s.slices + iz) * s.input_size;
|
||||
auto z_in = in + (ib + iz * s.input_batch) * s.input_size;
|
||||
auto max4 = is_empty ? 0 : z_in[h_start * s.input_width + w_start];
|
||||
for (int y = h_start; y < h_end; y++) {
|
||||
auto y_in = z_in + y * s.input_width;
|
||||
|
|
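The ROI pooling changes compute bin sizes in float and read each ROI record with a stride of 5 values rather than the earlier padded 8; every output cell then max-pools the clamped input window [floor(x * bin_w), ceil((x + 1) * bin_w)). A C++ reference for the bin boundary math (illustrative names):

#include <algorithm>
#include <cmath>

// Clamped input range covered by output bin `ox` of an ROI of width roi_w spread
// over output_width bins, mirroring the floor/ceil/clamp logic in the kernel.
inline void roiBinRange(int x1, int roi_w, int output_width, int input_width,
                        int ox, int& w_start, int& w_end) {
    float bin_size_w = (float)roi_w / (float)output_width;
    w_start = std::min(std::max(x1 + (int)std::floor(ox * bin_size_w), 0), input_width);
    w_end   = std::min(std::max(x1 + (int)std::ceil((ox + 1) * bin_size_w), 0), input_width);
    // If w_end <= w_start the bin is empty and contributes zero, as in the kernel.
}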
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
kernel void relu_x1(const device ftype *in [[buffer(0)]],
|
||||
device ftype *out [[buffer(1)]],
|
||||
constant float &slope [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
auto value = in[int(gid)];
|
||||
out[int(gid)] = fmax(value, (ftype)0) + fmin(value, (ftype)0) * ftype(slope);
|
||||
}
|
||||
|
||||
kernel void relu_x4(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant float &slope [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
auto value = in[int(gid)];
|
||||
out[int(gid)] = fmax(value, (ftype4)0) + fmin(value, (ftype4)0) * ftype4(slope);
|
||||
}
|
||||
|
|
@ -1,13 +1,24 @@
|
|||
kernel void relu6_x1(const device ftype *in [[buffer(0)]],
|
||||
device ftype *out [[buffer(1)]],
|
||||
constant float4 &minMax [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = clamp(in[int(gid)], (ftype)(minMax.x), (ftype)(minMax.y));
|
||||
struct Param {
|
||||
float minV;
|
||||
float maxV;
|
||||
int size;
|
||||
int remain;
|
||||
};
|
||||
kernel void relu6(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant Param &p [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if (gid.x < p.size) {
|
||||
out[int(gid.x)] = clamp(in[int(gid.x)], (ftype4)p.minV, (ftype4)p.maxV);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void relu6_x4(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant float4 &minMax [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
out[int(gid)] = clamp(in[int(gid)], (ftype4)minMax.x, (ftype4)minMax.y);
|
||||
}
|
||||
kernel void relu(const device ftype4 *in [[buffer(0)]],
|
||||
device ftype4 *out [[buffer(1)]],
|
||||
constant Param &p [[buffer(2)]],
|
||||
uint3 gid [[thread_position_in_grid]]) {
|
||||
if (gid.x < p.size) {
|
||||
auto value = in[int(gid.x)];
|
||||
out[int(gid.x)] = fmax(value, (ftype4)0) + fmin(value, (ftype4)0) * ftype4(p.minV);
|
||||
}
|
||||
}
|
||||
|
|
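The rewritten shaders share one Param layout: relu6 clamps the input to [minV, maxV], while relu reuses minV as the leaky slope and evaluates fmax(v, 0) + fmin(v, 0) * slope, which is v for non-negative inputs and slope * v otherwise. Scalar C++ reference of both:

#include <algorithm>

inline float relu6Ref(float v, float minV, float maxV) {
    return std::min(std::max(v, minV), maxV);
}

// Leaky ReLU written the same way as the shader: positive part plus scaled negative part.
inline float reluRef(float v, float slope) {
    return std::max(v, 0.0f) + std::min(v, 0.0f) * slope;
}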
@ -12,7 +12,7 @@ kernel void scale_ca(const device ftype4 *in [[buffer(0)]],
|
|||
uint2 gid [[thread_position_in_grid]]) {
|
||||
if ((int)gid.x >= s.size || (int)gid.y >= s.steps * s.batch) return;
|
||||
|
||||
int z = gid.y % s.steps;
|
||||
int z = gid.y / s.batch;
|
||||
out[int(gid.y) * s.size + int(gid.x)] =
|
||||
in [int(gid.y) * s.size + int(gid.x)] * ftype4(scales[z]) + ftype4(biasTerms[z]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -294,6 +294,23 @@ private:
|
|||
ImagePool* mBufferPool;
|
||||
};
|
||||
|
||||
float OpenCLBackend::getBytes(const Tensor* tensor) {
|
||||
float bytes = (float)tensor->getType().bytes();
|
||||
if (getOpenCLRuntime()->isSupportedFP16()) {// Fp16
|
||||
if (halide_type_float == tensor->getType().code) {
|
||||
bytes = 2.0;
|
||||
}
|
||||
}
|
||||
auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get();
|
||||
if (nullptr != quant && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) {
|
||||
bytes = 1.0;
|
||||
}
|
||||
if(tensor->getType().bits == 4) {
|
||||
bytes = 0.5;
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
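getBytes reports a possibly fractional bytes-per-element figure: 2 for floats when FP16 is enabled, 1 for int8-quantized tensors, 0.5 for 4-bit types, otherwise the plain type size. Buffer sizes are then element count times this figure, with the element count rounded to an even number so half-byte elements pack cleanly (the ROUND_UP(size, 2) in onAcquire). A hedged sketch of that sizing, assuming the rounded byte total is what gets allocated:

#include <cmath>
#include <cstddef>

// Allocation size in bytes for `count` elements of `bytesPerElement`
// (which may be 0.5 for int4), using an even element count so the total is integral.
inline size_t allocBytes(size_t count, float bytesPerElement) {
    size_t alignedCount = (count + 1) / 2 * 2;
    return (size_t)std::ceil((double)alignedCount * (double)bytesPerElement);
}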
Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageType storageType) {
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("Start OpenCLBackend::onAcquireBuffer !\n");
|
||||
|
|
@ -312,7 +329,7 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
|
|||
#ifndef MNN_OPENCL_BUFFER_CLOSED
|
||||
if(mOpenCLRuntime->getGpuMemType() == BUFFER) {
|
||||
size_t size;
|
||||
size_t typeSize = 4;
|
||||
float typeSize = getBytes(nativeTensor);
|
||||
if (nativeTensor->dimensions() >= 2) {
|
||||
auto alignC = ROUND_UP(C, 8);
|
||||
// increment of height and width
|
||||
|
|
@ -331,16 +348,9 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
|
|||
size_t imageHeight = (size_t)N * H;
|
||||
size = imageWidth*imageHeight*cPack;
|
||||
}
|
||||
cl_channel_type dataType = CL_FLOAT;
|
||||
//when support and want fp16, use half datatype
|
||||
if ((nativeTensor->getType().code == halide_type_int || nativeTensor->getType().code == halide_type_uint)){
|
||||
if(nativeTensor->getType().bits == 8){
|
||||
typeSize = 1;
|
||||
}
|
||||
} else if (getOpenCLRuntime()->isSupportedFP16()) {
|
||||
typeSize = 2;
|
||||
}
|
||||
|
||||
// Align when int4 memory
|
||||
size = ROUND_UP(size, 2);
|
||||
|
||||
if (storageType == DYNAMIC_SEPERATE) {
|
||||
auto buffer = mBufferPool->alloc(size*typeSize, true);
|
||||
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer;
|
||||
|
|
@ -352,23 +362,8 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
|
|||
return new CLMemReleaseBuffer(buffer, mBufferPool);
|
||||
}
|
||||
MNN_ASSERT(storageType == STATIC);
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
// for the weights of weight-quantized models
|
||||
if ((nativeTensor->getType().code == halide_type_int) &&
|
||||
(nativeTensor->getType().bits == 8 || nativeTensor->getType().bits == 4)) {
|
||||
// int8 quant
|
||||
size_t alloc_size = size;
|
||||
if (nativeTensor->getType().bits == 4) {
|
||||
// int4 quant
|
||||
alloc_size = size / 2;
|
||||
}
|
||||
auto buffer = mStaticBufferPool->alloc(alloc_size);
|
||||
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer;
|
||||
return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get());
|
||||
}
|
||||
#endif
|
||||
auto buffer = mStaticBufferPool->alloc(size*
|
||||
(dataType == CL_HALF_FLOAT ? sizeof(half_float::half) : sizeof(float)));
|
||||
|
||||
auto buffer = mStaticBufferPool->alloc(size*typeSize);
|
||||
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; // fix
|
||||
return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get());
|
||||
}
|
||||
|
|
@ -569,7 +564,7 @@ ErrorCode OpenCLBackend::onResizeEnd() {
|
|||
mOpenCLRuntime->setCommandQueueProfileDisable();
|
||||
#endif
|
||||
if(!mRecordings.empty()){
|
||||
endRecord(mRecordings.back(), true);
|
||||
endRecord(mRecordings.back().record, true);
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
|
@ -1188,8 +1183,25 @@ void OpenCLBackend::clearRecord() const{
|
|||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
if(mUseRecordQueue && mDevideOpRecord){
|
||||
for(int i = 0; i < mRecordings.size(); ++i){
|
||||
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
|
||||
0, nullptr, 0, nullptr, 0, nullptr, nullptr);
|
||||
std::vector<cl_array_arg_qcom> update_kernel_args;
|
||||
std::vector<cl_workgroup_qcom> update_global_size;
|
||||
std::vector<cl_workgroup_qcom> update_local_size;
|
||||
for (int j = 0; j < mRecordings[i].updateInfo.size(); ++j){
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_kernel_args.size(); ++k){
|
||||
update_kernel_args.emplace_back(mRecordings[i].updateInfo[j]->update_kernel_args[k]);
|
||||
update_kernel_args.back().dispatch_index = j;
|
||||
}
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_global_size.size(); ++k){
|
||||
update_global_size.emplace_back(mRecordings[i].updateInfo[j]->update_global_size[k]);
|
||||
update_global_size.back().dispatch_index = j;
|
||||
}
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_local_size.size(); ++k){
|
||||
update_local_size.emplace_back(mRecordings[i].updateInfo[j]->update_local_size[k]);
|
||||
update_local_size.back().dispatch_index = j;
|
||||
}
|
||||
}
|
||||
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i].record, update_kernel_args.size(), update_kernel_args.data(), 0, nullptr,
|
||||
update_global_size.size(), update_global_size.data(), update_local_size.size(), update_local_size.data(), 0, nullptr, nullptr);
|
||||
MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM");
|
||||
}
|
||||
mOpenCLRuntime->commandQueue().finish();
|
||||
|
|
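clearRecord (and enqeueRecord below) now walk each recording's updateInfo list and flatten the per-dispatch kernel-argument and work-size updates into single arrays, stamping every entry with the dispatch it belongs to, before handing them to EnqueueRecordingQCOM. A condensed C++ sketch of that flattening with stand-in types (the real ones are cl_array_arg_qcom / cl_workgroup_qcom from the Qualcomm recording extension):

#include <cstdint>
#include <vector>

// Minimal stand-ins; only the field used here is modeled.
struct ArgUpdate  { uint32_t dispatch_index = 0; };
struct SizeUpdate { uint32_t dispatch_index = 0; };

struct RecordUpdate {                      // stand-in for RecordUpdateInfo
    std::vector<ArgUpdate>  update_kernel_args;
    std::vector<SizeUpdate> update_global_size;
    std::vector<SizeUpdate> update_local_size;
};

// Flatten per-dispatch updates and tag each with its dispatch index,
// as done before EnqueueRecordingQCOM in OpenCLBackend::clearRecord.
void flattenUpdates(const std::vector<RecordUpdate*>& updateInfo,
                    std::vector<ArgUpdate>& args,
                    std::vector<SizeUpdate>& globals,
                    std::vector<SizeUpdate>& locals) {
    for (size_t j = 0; j < updateInfo.size(); ++j) {
        for (auto a : updateInfo[j]->update_kernel_args) { a.dispatch_index = (uint32_t)j; args.push_back(a); }
        for (auto g : updateInfo[j]->update_global_size) { g.dispatch_index = (uint32_t)j; globals.push_back(g); }
        for (auto l : updateInfo[j]->update_local_size)  { l.dispatch_index = (uint32_t)j; locals.push_back(l); }
    }
}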
@ -1202,8 +1214,22 @@ void OpenCLBackend::enqeueRecord() const{
|
|||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
if(mUseRecordQueue && !mDevideOpRecord){
|
||||
for(int i = 0; i < mRecordings.size(); ++i){
|
||||
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
|
||||
0, nullptr, 0, nullptr, 0, nullptr, nullptr);
|
||||
std::vector<cl_array_arg_qcom> update_kernel_args;
|
||||
std::vector<cl_workgroup_qcom> update_global_size;
|
||||
std::vector<cl_workgroup_qcom> update_local_size;
|
||||
for (int j = 0; j < mRecordings[i].updateInfo.size(); ++j){
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_kernel_args.size(); ++k){
|
||||
update_kernel_args.emplace_back(mRecordings[i].updateInfo[j]->update_kernel_args[k]);
|
||||
}
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_global_size.size(); ++k){
|
||||
update_global_size.emplace_back(mRecordings[i].updateInfo[j]->update_global_size[k]);
|
||||
}
|
||||
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_local_size.size(); ++k){
|
||||
update_local_size.emplace_back(mRecordings[i].updateInfo[j]->update_local_size[k]);
|
||||
}
|
||||
}
|
||||
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i].record, update_kernel_args.size(), update_kernel_args.data(), 0, nullptr,
|
||||
update_global_size.size(), update_global_size.data(), update_local_size.size(), update_local_size.data(), 0, nullptr, nullptr);
|
||||
MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM");
|
||||
}
|
||||
mOpenCLRuntime->commandQueue().finish();
|
||||
|
|
@ -1215,7 +1241,7 @@ void OpenCLBackend::releaseRecord(){
|
|||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
if(mUseRecordQueue && !mDevideOpRecord){
|
||||
for(int i = 0; i < mRecordings.size(); ++i){
|
||||
cl_int res = clReleaseRecordingQCOM(mRecordings[i]);
|
||||
cl_int res = clReleaseRecordingQCOM(mRecordings[i].record);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM");
|
||||
}
|
||||
mRecordings.clear();
|
||||
|
|
@ -1258,8 +1284,9 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
|
|||
res = clEndRecordingQCOM(recording);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
|
||||
} else if(flag) {
|
||||
// endRecord for last kernel be recorded when record mode is MNN_GPU_RECORD_BATCH
|
||||
if(!mRecordings.empty()){
|
||||
cl_int res = clEndRecordingQCOM(mRecordings.back());
|
||||
cl_int res = clEndRecordingQCOM(mRecordings.back().record);
|
||||
mRecordNums = 0;
|
||||
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
|
||||
}
|
||||
|
|
@ -1270,7 +1297,18 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
|
|||
#endif //ENABLE_OPENCL_TIME_PROFILER
|
||||
}
|
||||
|
||||
void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws) {
|
||||
void OpenCLBackend::addRecord(cl_recording_qcom &record, std::vector<RecordUpdateInfo *>updateInfo){
|
||||
if(mDevideOpRecord){
|
||||
RecordInfo info;
|
||||
info.record = record;
|
||||
for(int i = 0; i < updateInfo.size(); ++i) {
|
||||
info.updateInfo.emplace_back(updateInfo[i]);
|
||||
}
|
||||
mRecordings.emplace_back(info);
|
||||
}
|
||||
}
|
||||
|
||||
void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo) {
|
||||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
if(!mUseRecordQueue){
|
||||
return;
|
||||
|
|
@ -1281,17 +1319,35 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
|
|||
#endif
|
||||
cl_int res = CL_SUCCESS;
|
||||
if(!mDevideOpRecord){
|
||||
RecordInfo info;
|
||||
if(updateInfo != nullptr){
|
||||
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
|
||||
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
|
||||
updateInfo->update_global_size[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
|
||||
updateInfo->update_local_size[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
info.updateInfo.emplace_back(updateInfo);
|
||||
}
|
||||
if(mRecordNums == 0){
|
||||
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
|
||||
mRecordings.emplace_back(recording);
|
||||
info.record = recording;
|
||||
mRecordings.emplace_back(info);
|
||||
}else if(mRecordNums == mUseRecordableQueueSize){
|
||||
res = clEndRecordingQCOM(mRecordings.back());
|
||||
res = clEndRecordingQCOM(mRecordings.back().record);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
|
||||
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
|
||||
mRecordings.emplace_back(recording);
|
||||
info.record = recording;
|
||||
mRecordings.emplace_back(info);
|
||||
mRecordNums = 0;
|
||||
} else if(updateInfo != nullptr){
|
||||
auto &lastInfo = mRecordings.back();
|
||||
lastInfo.updateInfo.emplace_back(updateInfo);
|
||||
}
|
||||
mRecordNums++;
|
||||
}
|
||||
|
|
@ -1317,7 +1373,7 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
|
|||
#endif //ENABLE_OPENCL_TIME_PROFILER
|
||||
}
|
||||
|
||||
void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws) {
|
||||
void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo) {
|
||||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
if(!mUseRecordQueue){
|
||||
return;
|
||||
|
|
@ -1332,17 +1388,35 @@ void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, c
|
|||
internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i]));
|
||||
}
|
||||
if(!mDevideOpRecord){
|
||||
RecordInfo info;
|
||||
if(updateInfo != nullptr){
|
||||
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
|
||||
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
|
||||
updateInfo->update_global_size[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
|
||||
updateInfo->update_local_size[i].dispatch_index = mRecordNums;
|
||||
}
|
||||
info.updateInfo.emplace_back(updateInfo);
|
||||
}
|
||||
if(mRecordNums == 0){
|
||||
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
|
||||
mRecordings.emplace_back(recording);
|
||||
info.record = recording;
|
||||
mRecordings.emplace_back(info);
|
||||
}else if(mRecordNums == mUseRecordableQueueSize){
|
||||
res = clEndRecordingQCOM(mRecordings.back());
|
||||
res = clEndRecordingQCOM(mRecordings.back().record);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
|
||||
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
|
||||
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
|
||||
mRecordings.emplace_back(recording);
|
||||
info.record = recording;
|
||||
mRecordings.emplace_back(info);
|
||||
mRecordNums = 0;
|
||||
} else if(updateInfo != nullptr){
|
||||
auto &lastInfo = mRecordings.back();
|
||||
lastInfo.updateInfo.emplace_back(updateInfo);
|
||||
}
|
||||
mRecordNums++;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,6 +34,15 @@
|
|||
namespace MNN {
|
||||
namespace OpenCL {
|
||||
struct TuneInfo;
|
||||
struct RecordUpdateInfo{
|
||||
std::vector<cl_array_arg_qcom> update_kernel_args;
|
||||
std::vector<cl_workgroup_qcom> update_global_size;
|
||||
std::vector<cl_workgroup_qcom> update_local_size;
|
||||
};
|
||||
struct RecordInfo{
|
||||
cl_recording_qcom record;
|
||||
std::vector<RecordUpdateInfo*> updateInfo;
|
||||
};
|
||||
class CLRuntime : public Runtime {
|
||||
public:
|
||||
CLRuntime(const Backend::Info& info, int platformSize, int platformId, int deviceId = 0, void *contextPtr = nullptr, void *glshared = nullptr);
|
||||
|
|
@ -111,6 +120,7 @@ public:
|
|||
return mMemory;
|
||||
}
|
||||
|
||||
float getBytes(const Tensor* tensor);
|
||||
DataType getDataType(const Tensor* tensor);
|
||||
|
||||
cl_channel_type fpType();
|
||||
|
|
@ -125,11 +135,9 @@ public:
|
|||
bool isDevideOpRecord(){
|
||||
return mDevideOpRecord;
|
||||
}
|
||||
void addRecord(cl_recording_qcom &record){
|
||||
mRecordings.emplace_back(record);
|
||||
}
|
||||
void recordKernel2d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws);
|
||||
void recordKernel3d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws);
|
||||
void addRecord(cl_recording_qcom &record, std::vector<RecordUpdateInfo *>updateInfo);
|
||||
void recordKernel2d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo = nullptr);
|
||||
void recordKernel3d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo = nullptr);
|
||||
void startRecord(cl_recording_qcom &recording);
|
||||
void endRecord(cl_recording_qcom &recording, bool flag = false);
|
||||
|
||||
|
|
@ -167,7 +175,7 @@ private:
|
|||
BackendConfig::PrecisionMode mPrecision;
|
||||
BackendConfig::MemoryMode mMemory;
|
||||
bool mIsCreateError{false};
|
||||
mutable std::vector<cl_recording_qcom> mRecordings;
|
||||
mutable std::vector<RecordInfo> mRecordings;
|
||||
bool mUseRecordQueue = false;
|
||||
bool mDevideOpRecord = false;
|
||||
uint32_t mRecordNums = 0;
|
||||
|
|
@ -202,6 +210,17 @@ public:
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef MNN_OPENCL_SEP_BUILD
|
||||
#define REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(name, opType, memObj) \
|
||||
OpenCLCreatorRegister<name> ___OpenCL##name##__##opType##__##memObj##__(opType, memObj)
|
||||
#else
|
||||
#define REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(name, opType, memObj) \
|
||||
void ___OpenCL##name##__##opType##__##memObj##__() { \
|
||||
static name _temp; \
|
||||
OpenCLBackend::addCreator(std::make_pair(opType, memObj), &_temp); \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
class TypedCreator : public OpenCLBackend::Creator {
|
||||
|
|
|
|||
|
|
@ -1,122 +1,127 @@
|
|||
// This file is generated by Shell for ops register
|
||||
#ifndef MNN_OPENCL_SEP_BUILD
|
||||
|
||||
namespace MNN {
|
||||
namespace OpenCL {
|
||||
extern void ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
|
||||
extern void ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
|
||||
extern void ___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
|
||||
extern void ___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
|
||||
extern void ___OpenCLScaleCreator__OpType_Scale__IMAGE__();
|
||||
extern void ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
|
||||
extern void ___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
|
||||
extern void ___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
|
||||
extern void ___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
|
||||
extern void ___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
|
||||
extern void ___OpenCLInterpCreator__OpType_Interp__IMAGE__();
|
||||
extern void ___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
|
||||
extern void ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
|
||||
extern void ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
|
||||
extern void ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
|
||||
extern void ___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
|
||||
extern void ___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
|
||||
extern void ___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
|
||||
extern void ___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
|
||||
extern void ___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
|
||||
extern void ___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
|
||||
extern void ___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
|
||||
extern void ___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
|
||||
extern void ___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
|
||||
extern void ___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
|
||||
extern void ___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
|
||||
extern void ___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
|
||||
extern void ___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
|
||||
extern void ___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
|
||||
extern void ___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
|
||||
extern void ___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
|
||||
extern void ___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
|
||||
extern void ___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
|
||||
extern void ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
|
||||
extern void ___OpenCLReluCreator__OpType_ReLU__IMAGE__();
|
||||
extern void ___OpenCLReluCreator__OpType_PReLU__IMAGE__();
|
||||
extern void ___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
|
||||
extern void ___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
|
||||
extern void ___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
|
||||
extern void ___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
|
||||
extern void ___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
|
||||
extern void ___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
|
||||
extern void ___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
|
||||
extern void ___OpenCLRasterCreator__OpType_Raster__IMAGE__();
|
||||
extern void ___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
|
||||
extern void ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
|
||||
extern void ___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
|
||||
extern void ___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
|
||||
extern void ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
|
||||
extern void ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
|
||||
extern void ___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
|
||||
extern void ___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
|
||||
extern void ___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
|
||||
extern void ___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
|
||||
extern void ___OpenCLLoopCreator__OpType_While__IMAGE__();
|
||||
extern void ___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
|
||||
extern void ___OpenCLLoopBufCreator__OpType_While__BUFFER__();
|
||||
extern void ___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
|
||||
extern void ___OpenCLFuseCreator__OpType_Extra__IMAGE__();
|
||||
extern void ___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
|
||||
extern void ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
|
||||
extern void ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
|
||||
extern void ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
|
||||
extern void ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
|
||||
extern void ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
|
||||
extern void ___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
|
||||
extern void ___OpenCLScaleCreator__OpType_Scale__IMAGE__();
|
||||
extern void ___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
|
||||
extern void ___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
|
||||
extern void ___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
|
||||
extern void ___OpenCLRangeCreator__OpType_Range__IMAGE__();
|
||||
extern void ___OpenCLCastCreator__OpType_Cast__IMAGE__();
|
||||
extern void ___OpenCLRasterCreator__OpType_Raster__IMAGE__();
|
||||
extern void ___OpenCLFuseCreator__OpType_Extra__IMAGE__();
|
||||
extern void ___OpenCLLoopCreator__OpType_While__IMAGE__();
|
||||
extern void ___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
|
||||
extern void ___OpenCLReluCreator__OpType_ReLU__IMAGE__();
|
||||
extern void ___OpenCLReluCreator__OpType_PReLU__IMAGE__();
|
||||
extern void ___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
|
||||
extern void ___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
|
||||
extern void ___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
|
||||
extern void ___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
|
||||
extern void ___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
|
||||
extern void ___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
|
||||
extern void ___OpenCLSelectCreator__OpType_Select__IMAGE__();
|
||||
extern void ___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
|
||||
extern void ___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
|
||||
extern void ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
|
||||
extern void ___OpenCLCastCreator__OpType_Cast__IMAGE__();
|
||||
extern void ___OpenCLInterpCreator__OpType_Interp__IMAGE__();
|
||||
extern void ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
extern void ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__();
|
||||
#endif
|
||||
void registerOpenCLOps() {
|
||||
___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
|
||||
___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
|
||||
___OpenCLScaleCreator__OpType_Scale__IMAGE__();
|
||||
___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
|
||||
___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
|
||||
___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
|
||||
___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
|
||||
___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
|
||||
___OpenCLInterpCreator__OpType_Interp__IMAGE__();
|
||||
___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
|
||||
___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
|
||||
___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
|
||||
___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
|
||||
___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
|
||||
___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
|
||||
___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
|
||||
___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
|
||||
___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
|
||||
___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
|
||||
___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
|
||||
___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
|
||||
___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
|
||||
___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
|
||||
___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
|
||||
___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
|
||||
___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
|
||||
___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
|
||||
___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
|
||||
___OpenCLReluCreator__OpType_ReLU__IMAGE__();
|
||||
___OpenCLReluCreator__OpType_PReLU__IMAGE__();
|
||||
___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
|
||||
___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
|
||||
___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
|
||||
___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
|
||||
___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
|
||||
___OpenCLRasterCreator__OpType_Raster__IMAGE__();
|
||||
___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
|
||||
___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
|
||||
___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
|
||||
___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
|
||||
___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
|
||||
___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
|
||||
___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
|
||||
___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
|
||||
___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
|
||||
___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
|
||||
___OpenCLLoopCreator__OpType_While__IMAGE__();
|
||||
___OpenCLLoopBufCreator__OpType_While__BUFFER__();
|
||||
___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
|
||||
___OpenCLFuseCreator__OpType_Extra__IMAGE__();
|
||||
___OpenCLRangeCreator__OpType_Range__IMAGE__();
|
||||
___OpenCLCastCreator__OpType_Cast__IMAGE__();
|
||||
___OpenCLSelectCreator__OpType_Select__IMAGE__();
|
||||
___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
|
||||
___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
|
||||
___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
|
||||
___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
|
||||
___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
|
||||
___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
|
||||
___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
|
||||
___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
|
||||
___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
|
||||
___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
|
||||
___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
|
||||
___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
|
||||
___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
|
||||
___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
|
||||
___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
|
||||
___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
|
||||
___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
|
||||
___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
|
||||
___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
|
||||
___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
|
||||
___OpenCLLoopBufCreator__OpType_While__BUFFER__();
|
||||
___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
|
||||
___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
|
||||
___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
|
||||
___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
|
||||
___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
|
||||
___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
|
||||
___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
|
||||
___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
|
||||
___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
|
||||
___OpenCLScaleCreator__OpType_Scale__IMAGE__();
|
||||
___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
|
||||
___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
|
||||
___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
|
||||
___OpenCLRangeCreator__OpType_Range__IMAGE__();
|
||||
___OpenCLRasterCreator__OpType_Raster__IMAGE__();
|
||||
___OpenCLFuseCreator__OpType_Extra__IMAGE__();
|
||||
___OpenCLLoopCreator__OpType_While__IMAGE__();
|
||||
___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
|
||||
___OpenCLReluCreator__OpType_ReLU__IMAGE__();
|
||||
___OpenCLReluCreator__OpType_PReLU__IMAGE__();
|
||||
___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
|
||||
___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
|
||||
___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
|
||||
___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
|
||||
___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
|
||||
___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
|
||||
___OpenCLSelectCreator__OpType_Select__IMAGE__();
|
||||
___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
|
||||
___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
|
||||
___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
|
||||
___OpenCLCastCreator__OpType_Cast__IMAGE__();
|
||||
___OpenCLInterpCreator__OpType_Interp__IMAGE__();
|
||||
___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -548,9 +548,7 @@ void runKernel2D(const ::std::shared_ptr<KernelWrap> &kernelw, const std::vector
|
|||
|
||||
void copyBufferToImage(OpenCLRuntime *runtime, const cl::Buffer &buffer, const cl::Image &image, int w, int h) {
|
||||
std::set<std::string> buildOptions;
|
||||
if(runtime->isWeightCpuTransHalf() == false) {
|
||||
buildOptions.emplace("-DBUFFER_INP_FP32");
|
||||
}
|
||||
buildOptions.emplace("-DBUFFER_INP_FP32");
|
||||
auto kernelW = runtime->buildKernelWithCache("copy_buffer_to_image2d", "copy_buffer_to_image2d", buildOptions);
|
||||
auto kernel = kernelW->get();
|
||||
auto status = kernel.setArg(0, buffer);
|
||||
|
|
|
|||
|
|
@ -230,9 +230,11 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
|
|||
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &mGPUComputeUnits);
|
||||
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_CLOCK_FREQUENCY, &mMaxFreq);
|
||||
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_MEM_ALLOC_SIZE, &mMaxMemAllocSize);
|
||||
cl_device_fp_config fpConfig;
|
||||
cl_device_fp_config fpConfig;
|
||||
auto success = mFirstGPUDevicePtr->getInfo(CL_DEVICE_HALF_FP_CONFIG, &fpConfig);
|
||||
mIsDeviceSupportedFP16 = CL_SUCCESS == success && fpConfig > 0;
|
||||
bool checkFp16Exetension = getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_khr_fp16");
|
||||
mIsDeviceSupportedFP16 = (mIsDeviceSupportedFP16 && checkFp16Exetension);
|
||||
|
||||
//set gpu mode, tuning level and memory object
|
||||
setGpuMode(cl_mode);
|
||||
|
|
@ -245,11 +247,17 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
|
|||
}
|
||||
}
|
||||
|
||||
auto permitFloat16 = false;
|
||||
if (precision == BackendConfig::Precision_Low || (mMemType == BUFFER && precision == BackendConfig::Precision_Normal)) {//buffer mode not support Normal Precision yet
|
||||
permitFloat16 = true;
|
||||
mPrecisionLevel = 1;
|
||||
if (mIsDeviceSupportedFP16) {
|
||||
if (precision == BackendConfig::Precision_Low) {
|
||||
mPrecisionLevel = 2;
|
||||
} else if (precision == BackendConfig::Precision_Normal && mMemType == BUFFER) {
|
||||
mPrecisionLevel = 0;
|
||||
}
|
||||
}
|
||||
mIsSupportedFP16 = mIsDeviceSupportedFP16 && permitFloat16;
|
||||
|
||||
// Is supported fp16 IO storage
|
||||
mIsSupportedFP16 = (mPrecisionLevel == 2 || mPrecisionLevel == 0);
|
||||
|
||||
if(getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_integer_dot_product_int8")){
|
||||
mSupportDotInt8 = true;
|
||||
|
|
@ -257,7 +265,7 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
|
|||
if(getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_integer_dot_product_accumulate_int8")){
|
||||
mSupportDotAccInt8 = true;
|
||||
}
|
||||
|
||||
|
||||
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
|
||||
{
|
||||
if((false == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isQcomError())
|
||||
|
|
@ -267,7 +275,7 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
|
|||
cl_int err;
|
||||
if(MaxRecordableQueueSize > 0){
|
||||
// TODO: Use setSessionHint to set the number of mUseRecordableQueueSize
|
||||
mUseRecordableQueueSize = 10;
|
||||
mUseRecordableQueueSize = MaxRecordableQueueSize;
|
||||
mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize;
|
||||
mUseRecordQueue = true;
|
||||
mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err);
|
||||
|
|
@ -435,13 +443,6 @@ std::vector<size_t> OpenCLRuntime::getMaxImage2DSize() {
|
|||
bool OpenCLRuntime::isSupportedFP16() const {
|
||||
return mIsSupportedFP16;
|
||||
}
|
||||
bool OpenCLRuntime::isWeightCpuTransHalf() const {
|
||||
#ifdef USE_HALF_WEIGHT_MEMORY
|
||||
return mIsSupportedFP16;
|
||||
#else
|
||||
return false;//most of time
|
||||
#endif
|
||||
}
|
||||
|
||||
bool OpenCLRuntime::isDeviceSupportedFP16() const {
|
||||
return mIsDeviceSupportedFP16;
|
||||
|
|
@ -536,10 +537,12 @@ std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernel(const std::string &progra
|
|||
std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernelWithCache(const std::string &programName, const std::string &kernelName,
|
||||
const std::set<std::string> &buildOptions, const Tensor *input, const Tensor *output, bool useCache) {
|
||||
std::string buildOptionsStr;
|
||||
if (mIsSupportedFP16) {
|
||||
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16";
|
||||
} else {
|
||||
buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16";
|
||||
if (mPrecisionLevel == 2) {// Fp16 Memory and fp16 compute
|
||||
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=half -DCOMPUTE_FLOAT2=half2 -DCOMPUTE_FLOAT3=half3 -DCOMPUTE_FLOAT4=half4 -DCOMPUTE_FLOAT8=half8 -DCOMPUTE_FLOAT16=half16 -DCONVERT_COMPUTE_FLOAT2=convert_half2 -DCONVERT_COMPUTE_FLOAT4=convert_half4 -DCONVERT_COMPUTE_FLOAT8=convert_half8 -DCONVERT_COMPUTE_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16";
|
||||
} else if (mPrecisionLevel == 0) {// Fp16 Memory and fp32 compute
|
||||
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DMNN_SUPPORT_FP16";
|
||||
} else {// Fp32 Memory and fp32 compute
|
||||
buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16";
|
||||
}
|
||||
|
||||
if(nullptr != input){
|
||||
|
|
@ -907,6 +910,7 @@ void OpenCLRuntime::printEventTime(){
|
|||
return;
|
||||
}
|
||||
int raster_num = 0, raster_time = 0;
|
||||
unsigned int conv_time = 0, while_time = 0;
|
||||
for(int i = 0; i < mEvents.size(); ++i){
|
||||
auto event = &mEvents[i].second;
|
||||
cl_int res = event->wait();
|
||||
|
|
@ -915,10 +919,16 @@ void OpenCLRuntime::printEventTime(){
|
|||
auto StopNanos = event->getProfilingInfo<CL_PROFILING_COMMAND_END>();
|
||||
auto kernel_time = (unsigned int)((StopNanos - StartNanos) / 1000.0);
|
||||
mKernelTime += kernel_time;
|
||||
if(mEvents[i].first == "ConvBuf2D" || (mEvents[i].first.length() >= 11 && mEvents[i].first.substr(0, 11) == "Convolution")) {
|
||||
conv_time += kernel_time;
|
||||
}
|
||||
if((mEvents[i].first.length() >= 5 && mEvents[i].first.substr(0, 5) == "While")) {
|
||||
while_time += kernel_time;
|
||||
}
|
||||
MNN_PRINT("kernel time = %d us %s\n", kernel_time, mEvents[i].first.c_str());
|
||||
}
|
||||
mEvents.clear();
|
||||
MNN_PRINT("total kernel time = %d us\n", mKernelTime);
|
||||
MNN_PRINT("total kernel time = %d us, conv time = %d us, while time = %d us\n", mKernelTime, conv_time, while_time);
|
||||
#endif
|
||||
}
|
||||
} // namespace MNN
|
||||
|
|
|
|||
|
|
@ -74,7 +74,6 @@ public:
|
|||
OpenCLRuntime &operator=(const OpenCLRuntime &) = delete;
|
||||
|
||||
bool isSupportedFP16() const;
|
||||
bool isWeightCpuTransHalf() const;
|
||||
bool isDeviceSupportedFP16() const;
|
||||
bool isDeviceSupportedLowPower() const;
|
||||
bool isSupportedDotInt8() const;
|
||||
|
|
@ -197,7 +196,9 @@ private:
|
|||
uint32_t mUseRecordableQueueSize;
|
||||
bool mUseRecordQueue = false;
|
||||
bool mDevideOpRecord = true;
|
||||
bool mIsSupportedFP16 = false;
|
||||
int mPrecisionLevel;
|
||||
|
||||
bool mIsSupportedFP16 = false;
|
||||
bool mIsDeviceSupportedFP16 = false;
|
||||
bool mIsDeviceSupportedLowPower = false;
|
||||
bool mSupportDotInt8 = false;
|
||||
|
|
|
|||
|
|
@ -33,6 +33,14 @@ bool OpenCLSymbols::LoadOpenCLLibrary() {
|
|||
"libGLES_mali.so",
|
||||
"libmali.so",
|
||||
"libOpenCL-pixel.so",
|
||||
/*
|
||||
#elif defined(__OHOS__)
|
||||
"/vendor/lib64/chipsetsdk/libGLES_mali.so",
|
||||
"/system/lib64/libGLES_mali.so",
|
||||
"libGLES_mali.so",
|
||||
"/vendor/lib64/chipsetsdk/libhvgr_v200.so",
|
||||
"/vendor/lib64/chipsetsdk/libEGI_imp1.so",
|
||||
*/
|
||||
#if defined(__aarch64__)
|
||||
// Qualcomm Adreno
|
||||
"/system/vendor/lib64/libOpenCL.so",
|
||||
|
|
@ -110,7 +118,7 @@ bool OpenCLSymbols::isSvmError() {
|
|||
bool OpenCLSymbols::isPropError() {
|
||||
return mPropError;
|
||||
}
|
||||
|
||||
|
||||
bool OpenCLSymbols::isQcomError() {
|
||||
return mQcomError;
|
||||
}
|
||||
|
|
@ -118,11 +126,11 @@ bool OpenCLSymbols::isQcomError() {
|
|||
bool OpenCLSymbols::isGlError() {
|
||||
return mGlError;
|
||||
}
|
||||
|
||||
|
||||
bool OpenCLSymbols::isCL1_2Error() {
|
||||
return mCL_12Error;
|
||||
}
|
||||
|
||||
|
||||
bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
||||
#if defined(WIN32)
|
||||
handle_ = LoadLibraryA(library_path.c_str());
|
||||
|
|
@ -133,38 +141,38 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mIsError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_SVM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
|
||||
if(func_name == nullptr){ \
|
||||
mSvmError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_PROP_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
|
||||
if(func_name == nullptr){ \
|
||||
mPropError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_QCOM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
|
||||
if(func_name == nullptr){ \
|
||||
mQcomError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_CL_12_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
|
||||
if(func_name == nullptr){ \
|
||||
mCL_12Error = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
|
||||
if(func_name == nullptr){ \
|
||||
mGlError = true; \
|
||||
}
|
||||
|
||||
|
||||
#else
|
||||
handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
if (handle_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
typedef void* (*loadOpenCLPointerFunc)(const char* name);
|
||||
typedef void (*enableOpenCLFunc)();
|
||||
loadOpenCLPointerFunc loadOpenCLPointer = nullptr;
|
||||
|
|
@ -180,7 +188,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mIsError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_SVM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
|
||||
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
|
||||
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
|
||||
|
|
@ -188,7 +196,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mSvmError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_PROP_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
|
||||
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
|
||||
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
|
||||
|
|
@ -196,7 +204,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mPropError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_QCOM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
|
||||
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
|
||||
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
|
||||
|
|
@ -204,7 +212,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mQcomError = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_CL_12_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
|
||||
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
|
||||
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
|
||||
|
|
@ -212,7 +220,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mCL_12Error = true; \
|
||||
}
|
||||
|
||||
|
||||
#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
|
||||
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
|
||||
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
|
||||
|
|
@ -220,7 +228,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
if(func_name == nullptr){ \
|
||||
mGlError = true; \
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs);
|
||||
|
|
@ -277,7 +285,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
|
|||
MNN_LOAD_CL_12_PTR(clCreateImage);
|
||||
MNN_LOAD_CL_12_PTR(clRetainDevice);
|
||||
MNN_LOAD_CL_12_PTR(clReleaseDevice);
|
||||
|
||||
|
||||
MNN_LOAD_PROP_PTR(clCreateCommandQueueWithProperties);
|
||||
MNN_LOAD_SVM_PTR(clSVMAlloc);
|
||||
MNN_LOAD_SVM_PTR(clSVMFree);
|
||||
|
|
@ -707,7 +715,7 @@ cl_mem CL_API_CALL clCreateFromGLTexture(cl_context context,
|
|||
auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLTexture;
|
||||
MNN_CHECK_NOTNULL(func);
|
||||
return func(context, flags, target, miplevel, texture, errcode_ret);
|
||||
|
||||
|
||||
}
|
||||
|
||||
cl_int CL_API_CALL clEnqueueAcquireGLObjects(cl_command_queue command_queue,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,427 @@
|
|||
//
|
||||
// SoftmaxBufExecution.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2024/04/11.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifndef MNN_OPENCL_BUFFER_CLOSED
|
||||
|
||||
#include "backend/opencl/execution/buffer/AttentionBufExecution.hpp"
|
||||
|
||||
namespace MNN {
|
||||
namespace OpenCL {
|
||||
|
||||
AttentionBufImpl::AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cahce)
|
||||
: mKVCache(kv_cahce){
|
||||
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
|
||||
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_channel", {"-DSOFTMAX_LOCAL_SIZE=512"});
|
||||
mMaxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel));
|
||||
}
|
||||
|
||||
void AttentionBufImpl::allocKVCache() {
|
||||
if (!mKVCache || mPastLength < mMaxLength) {
|
||||
return;
|
||||
}
|
||||
mMaxLength = mPastLength + mExpandChunk;
|
||||
int byte = 4;
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
|
||||
byte = 2;
|
||||
}
|
||||
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * byte;
|
||||
// past_key: [1, numhead, headdim, maxlen]
|
||||
mPastKey.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
// past_value: [1, numhead, maxlen, headdim]
|
||||
mPastValue.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
}
|
||||
|
||||
void AttentionBufImpl::reallocKVCache() {
|
||||
if (!mKVCache || mPastLength < mMaxLength) {
|
||||
return;
|
||||
}
|
||||
int byte = 4;
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
|
||||
byte = 2;
|
||||
}
|
||||
size_t old_size = mNumHead * UP_DIV(mMaxLength, 4) * mHeadDim * 4 * byte;
|
||||
mMaxLength = mPastLength + mExpandChunk;
|
||||
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * byte;
|
||||
// past_key: [1, numhead, headdim, maxlen]
|
||||
auto new_key = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
|
||||
// past_value: [1, numhead, maxlen, headdim]
|
||||
auto new_value = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
|
||||
// copy
|
||||
cl_int res;
|
||||
auto new_key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_key, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
|
||||
auto key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastKey.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
|
||||
if(new_key_ptr != nullptr && key_ptr != nullptr && res == CL_SUCCESS){
|
||||
::memcpy(new_key_ptr, key_ptr, old_size);
|
||||
}else{
|
||||
MNN_ERROR("Map error key_ptr == nullptr \n");
|
||||
MNN_ASSERT(false);
|
||||
}
|
||||
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_key, new_key_ptr);
|
||||
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastKey.get(), key_ptr);
|
||||
|
||||
auto new_value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_value, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
|
||||
auto value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastValue.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
|
||||
if(new_value_ptr != nullptr && value_ptr != nullptr && res == CL_SUCCESS){
|
||||
::memcpy(new_value_ptr, value_ptr, old_size);
|
||||
}else{
|
||||
MNN_ERROR("Map error value_ptr == nullptr \n");
|
||||
MNN_ASSERT(false);
|
||||
}
|
||||
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_value, new_value_ptr);
|
||||
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastValue.get(), value_ptr);
|
||||
|
||||
mPastKey.reset(new_key);
|
||||
mPastValue.reset(new_value);
|
||||
size_t temp_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * byte;
|
||||
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
|
||||
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
|
||||
|
||||
// reset memory for args
|
||||
if(mOpenCLBackend->isUseRecordQueue()){
|
||||
mQkUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempQK.get()))();
|
||||
mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mPastKey.get()))();
|
||||
mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempQK.get()))();
|
||||
mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempSoftMax.get()))();
|
||||
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempSoftMax.get()))();
|
||||
mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mPastValue.get()))();
|
||||
}else{
|
||||
cl_int ret = CL_SUCCESS;
|
||||
ret |= mKernel_qk->get().setArg(5, *mTempQK.get());
|
||||
ret |= mKernel_qk->get().setArg(6, *mPastKey.get());
|
||||
ret |= mKernel_softmax->get().setArg(3, *mTempQK.get());
|
||||
ret |= mKernel_softmax->get().setArg(4, *mTempSoftMax.get());
|
||||
ret |= mKernel_qkv->get().setArg(3, *mTempSoftMax.get());
|
||||
ret |= mKernel_qkv->get().setArg(6, *mPastValue.get());
|
||||
MNN_CHECK_CL_SUCCESS(ret, "reset memory arg for AttentionBufExecution");
|
||||
}
|
||||
}
|
||||
|
||||
int AttentionBufImpl::getLocalSize(int size, int maxGroupSize){
|
||||
int local_size = 1;
|
||||
while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){
|
||||
local_size *= 2;
|
||||
}
|
||||
return local_size;
|
||||
}
|
||||
|
||||
ErrorCode AttentionBufImpl::onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
|
||||
mOpenCLBackend->startRecord(mRecording);
|
||||
//clear update arg vector, if prefill and decode use the same one
|
||||
mOpRecordUpdateInfo.clear();
|
||||
mQkUpdateInfo.update_kernel_args.clear();
|
||||
mQkUpdateInfo.update_global_size.clear();
|
||||
mQkUpdateInfo.update_local_size.clear();
|
||||
mSoftMaxUpdateInfo.update_kernel_args.clear();
|
||||
mSoftMaxUpdateInfo.update_global_size.clear();
|
||||
mSoftMaxUpdateInfo.update_local_size.clear();
|
||||
mQkvUpdateInfo.update_kernel_args.clear();
|
||||
mQkvUpdateInfo.update_global_size.clear();
|
||||
mQkvUpdateInfo.update_local_size.clear();
|
||||
|
||||
auto query = inputs[0];
|
||||
auto key = inputs[1];
|
||||
auto value = inputs[2];
|
||||
auto mask = inputs[3];
|
||||
auto runtime = mOpenCLBackend->getOpenCLRuntime();
|
||||
auto shape = query->shape();
|
||||
|
||||
int seq_len = shape[1];
|
||||
mNumHead = shape[2];
|
||||
mHeadDim = shape[3];
|
||||
mScale = 1.0 / sqrt(mHeadDim);
|
||||
mIsDecode = seq_len == 1;
|
||||
mIsFirstDecode = true;
|
||||
if (mPastLength == 0 || seq_len > 1) {
|
||||
mPastLength = seq_len;
|
||||
}
|
||||
mKv_seq_len = mPastLength;
|
||||
if(mIsDecode){
|
||||
mKv_seq_len = mPastLength + 1;
|
||||
}
|
||||
|
||||
allocKVCache();
|
||||
int byte = 4;
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
|
||||
byte = 2;
|
||||
}
|
||||
if (mIsDecode) {
|
||||
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * byte;
|
||||
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
} else {
|
||||
size_t buffer_size = UP_DIV(mPastLength, 4) * mPastLength * mNumHead * 4 * byte;
|
||||
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
|
||||
}
|
||||
|
||||
// query * key -> div -> select
|
||||
{
|
||||
std::set<std::string> buildOption;
|
||||
if(!mIsDecode){
|
||||
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
|
||||
}
|
||||
if((mHeadDim % 4) != 0){
|
||||
buildOption.emplace("-DHEADDIM_LEAVE");
|
||||
}
|
||||
mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask", buildOption, inputs[0], outputs[0]);
|
||||
mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mKv_seq_len, 4))};
|
||||
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qk));
|
||||
mGlobalWorkSizeQk2 = UP_DIV(mKv_seq_len, 4);
|
||||
|
||||
uint32_t index = 0;
|
||||
cl_int ret = CL_SUCCESS;
|
||||
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
|
||||
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
|
||||
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk2);
|
||||
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query));
|
||||
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(key));
|
||||
ret |= mKernel_qk->get().setArg(index++, *mTempQK.get());
|
||||
ret |= mKernel_qk->get().setArg(index++, *mPastKey.get());
|
||||
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mask));
|
||||
ret |= mKernel_qk->get().setArg(index++, mScale);
|
||||
ret |= mKernel_qk->get().setArg(index++, seq_len);
|
||||
ret |= mKernel_qk->get().setArg(index++, mKv_seq_len);
|
||||
ret |= mKernel_qk->get().setArg(index++, mNumHead);
|
||||
ret |= mKernel_qk->get().setArg(index++, mHeadDim);
|
||||
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask");
|
||||
|
||||
mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask", mKernel_qk).first;
|
||||
mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
|
||||
mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
|
||||
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2]));
|
||||
mQkUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(mGlobalWorkSizeQk2), &mGlobalWorkSizeQk2});
|
||||
mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &(*(mTempQK.get()))()});
|
||||
mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastKey.get()))()});
|
||||
mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKv_seq_len), &mKv_seq_len});
|
||||
mQkGlobal_size[0] = mGlobalWorkSizeQk[0];
|
||||
mQkGlobal_size[1] = mGlobalWorkSizeQk[1];
|
||||
mQkGlobal_size[2] = mGlobalWorkSizeQk[2];
|
||||
mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size});
|
||||
mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
|
||||
mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
|
||||
}
|
||||
|
||||
// softmax
|
||||
{
|
||||
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
|
||||
int localSize = getLocalSize(mKv_seq_len, MaxLocalSize);
|
||||
if(localSize < 4){
|
||||
localSize = 1;
|
||||
}
|
||||
int past_len4 = UP_DIV(mKv_seq_len, 4);
|
||||
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
|
||||
mSoftmaxShape[0] = mNumHead;
|
||||
mSoftmaxShape[1] = past_len4;
|
||||
mSoftmaxShape[2] = 1;
|
||||
mSoftmaxShape[3] = mPastLength;
|
||||
std::set<std::string> buildOption;
|
||||
buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
|
||||
if(!mIsDecode){
|
||||
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_width", buildOption, inputs[0], outputs[0]);
|
||||
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(past_len4), static_cast<uint32_t>(mNumHead)};
|
||||
} else{
|
||||
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_channel", buildOption, inputs[0], outputs[0]);
|
||||
mSoftmaxShape[3] = 1;
|
||||
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(1), static_cast<uint32_t>(mNumHead)};
|
||||
}
|
||||
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_softmax));
|
||||
|
||||
uint32_t index = 0;
|
||||
cl_int ret = CL_SUCCESS;
|
||||
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
|
||||
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
|
||||
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
|
||||
ret |= mKernel_softmax->get().setArg(index++, *mTempQK.get());
|
||||
ret |= mKernel_softmax->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_softmax->get().setArg(index++, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(index++, mSoftmaxShape);
MNN_CHECK_CL_SUCCESS(ret, "setArg softmax");

mLocalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), 1, 1};
if(localSize == 1){
mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax).first;
}
mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempQK.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mSoftMaxRemainChannels), &mSoftMaxRemainChannels});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mSoftmaxShape), &mSoftmaxShape});
mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
}

// qk * value
{
std::set<std::string> buildOption;
if(!mIsDecode){
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
}
if((mHeadDim % 4) != 0){
buildOption.emplace("-DHEADDIM_LEAVE");
}
mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv", buildOption, inputs[0], outputs[0]);
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qkv));
mGlobalWorkSizeQkv = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(mHeadDim)};
if(mIsDecode){
mGlobalWorkSizeQkv = {static_cast<uint32_t>(1), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mHeadDim, 4))};
}

uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]);
ret |= mKernel_qkv->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(value));
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
ret |= mKernel_qkv->get().setArg(index++, *mPastValue.get());
ret |= mKernel_qkv->get().setArg(index++, seq_len);
ret |= mKernel_qkv->get().setArg(index++, mKv_seq_len);
ret |= mKernel_qkv->get().setArg(index++, mNumHead);
ret |= mKernel_qkv->get().setArg(index++, mHeadDim);
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv");

mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv", mKernel_qkv).first;
mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2]));

mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastValue.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 8, sizeof(mKv_seq_len), &mKv_seq_len});
mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo);
}

mOpenCLBackend->endRecord(mRecording);
return NO_ERROR;
}

ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("start AttentionBufExecution onExecute !\n");
#endif
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
reallocKVCache();
#ifdef ENABLE_OPENCL_TIME_PROFILER
{
cl::Event event;
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk,
mOpenCLBackend->getOpenCLRuntime(), &event);

mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax,
mOpenCLBackend->getOpenCLRuntime(), &event);

mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv,
mOpenCLBackend->getOpenCLRuntime(), &event);

mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event});
}
#else
if(mOpenCLBackend->isUseRecordQueue()){
mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
if(mIsDecode){
if(mIsFirstDecode){
mIsFirstDecode = false;
}else{
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
mGlobalWorkSizeQk2 = past_len4;
mQkGlobal_size[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
}
}
#ifdef LOG_VERBOSE
MNN_PRINT("End AttentionBufExecution onExecute... \n");
#endif
return NO_ERROR;
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime());
#endif

// decode
if(mIsDecode){
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
cl_int ret = CL_SUCCESS;
mGlobalWorkSizeQk2 = past_len4;
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
ret |= mKernel_qk->get().setArg(2, mGlobalWorkSizeQk2);
ret |= mKernel_qk->get().setArg(10, mKv_seq_len);
ret |= mKernel_softmax->get().setArg(5, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(6, mSoftmaxShape);
ret |= mKernel_qkv->get().setArg(8, mKv_seq_len);
MNN_CHECK_CL_SUCCESS(ret, "reset arg for AttentionBufExecution");
}
#ifdef LOG_VERBOSE
MNN_PRINT("end AttentionBufExecution onExecute !\n");
#endif

return NO_ERROR;
}
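The decode path above advances the cached sequence length by one token per call and re-derives the channel-of-4 padding that the softmax kernel has to ignore. A minimal standalone sketch of that bookkeeping, assuming the usual UP_DIV/ROUND_UP semantics of the macros used above (helper names are illustrative, not MNN code):

    // Sketch only: decode-step KV-length and softmax-padding arithmetic.
    #include <cstdio>

    static int up_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        int pastLength = 16;                    // tokens already in the KV cache
        pastLength += 1;                        // one new decode step
        int kvSeqLen  = pastLength + 1;         // cached tokens plus the current one
        int pastLen4  = up_div(kvSeqLen, 4);    // kv length in blocks of 4 channels
        int remain    = pastLen4 * 4 - kvSeqLen; // padded lanes the softmax must skip
        printf("kv_seq_len=%d blocks=%d remain=%d\n", kvSeqLen, pastLen4, remain);
        return 0;
    }

With the record queue enabled, these refreshed values flow into the kernels through the update_kernel_args entries registered during onResize instead of per-call setArg calls.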
AttentionBufExecution::AttentionBufExecution(const MNN::Op *op, Backend* backend, bool kv_cache) : CommonExecution(backend, op) {
mImpl.reset(new AttentionBufImpl(op, backend, kv_cache));
}

AttentionBufExecution::AttentionBufExecution(std::shared_ptr<AttentionBufImpl> impl, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mImpl(impl) {}

ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onResize(backend(), inputs, outputs);
}

ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onExecute(backend(), inputs, outputs);
}

bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new AttentionBufExecution(mImpl, op, bn);
return true;
}

class AttentionBufCreator : public OpenCLBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
auto param = op->main_as_AttentionParam();
return new AttentionBufExecution(op, backend, param->kv_cache());
}
};
REGISTER_OPENCL_OP_CREATOR(AttentionBufCreator, OpType_Attention, BUFFER);

} // namespace OpenCL
} // namespace MNN
#endif/* MNN_OPENCL_BUFFER_CLOSED */
@ -0,0 +1,82 @@
//
// AttentionBufExecution.hpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MNN_OPENCL_BUFFER_CLOSED

#ifndef AttentionBufExecution_hpp
#define AttentionBufExecution_hpp

#include "backend/opencl/execution/image/CommonExecution.hpp"

namespace MNN {
namespace OpenCL {

class AttentionBufImpl {
public:
AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cache);

~AttentionBufImpl() {
if(mRecording != NULL){
#ifdef MNN_USE_LIB_WRAPPER
clReleaseRecordingQCOM(mRecording);
#endif
}
}
ErrorCode onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);

private:
int getLocalSize(int size, int maxGroupSize);
void allocKVCache();
void reallocKVCache();
bool mKVCache;
float mScale;
const int mExpandChunk = 64;
bool mIsDecode = false;
bool mIsFirstDecode = true;
int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0, mSoftMaxRemainChannels = 0;
std::shared_ptr<cl::Buffer> mPastKey, mPastValue, mTempQK, mTempSoftMax;
int mNumHead = 0, mHeadDim = 0, mValueH = 0;
std::shared_ptr<KernelWrap> mKernel_qk;
std::shared_ptr<KernelWrap> mKernel_softmax;
std::shared_ptr<KernelWrap> mKernel_qkv;
std::vector<uint32_t> mGlobalWorkSizeQk{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQk{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeSoftMax{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeSoftMax{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeQkv{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQkv{1, 1, 1, 1};
uint32_t mMaxWorkGroupSize;
OpenCLBackend *mOpenCLBackend;
RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo;
RecordUpdateInfo mQkvUpdateInfo;
int mGlobalWorkSizeQk2 = 0;
size_t mQkGlobal_size[3];
int mSoftmaxShape[4];
cl_recording_qcom mRecording{NULL};
std::vector<RecordUpdateInfo*> mOpRecordUpdateInfo;
};

class AttentionBufExecution : public CommonExecution {
public:
AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache);
AttentionBufExecution(std::shared_ptr<AttentionBufImpl> impl, const MNN::Op *op, Backend *backend);

virtual ~AttentionBufExecution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;

private:
std::shared_ptr<AttentionBufImpl> mImpl;
};
} // namespace OpenCL
} // namespace MNN
#endif /* AttentionBufExecution_hpp */
#endif /* MNN_OPENCL_BUFFER_CLOSED */
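The header above keeps all kernels, work sizes, and the KV cache inside AttentionBufImpl and hands a shared_ptr to each execution, so onClone only copies the handle and cloned executions keep reusing the same cache. A minimal sketch of that ownership pattern with illustrative types only (not MNN classes):

    // Sketch only: shared-implementation clone pattern.
    #include <memory>

    struct Impl { /* kernels, KV cache buffers, work sizes ... */ };

    struct Exec {
        std::shared_ptr<Impl> impl;
        // A cloned Exec shares the same Impl, so the KV cache is reused rather than copied.
        Exec clone() const { return Exec{impl}; }
    };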
@ -379,7 +379,7 @@ public:
case BinaryOpOperation_NOTEQUAL:
return new BinaryBufExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend);
case BinaryOpOperation_MOD:
return new BinaryBufExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend);
return new BinaryBufExecution(inputs, "fmod(in0,in1)", op, backend);
default:
break;
}
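The two MOD expressions above are the old and new kernel bodies: the old one computes a floored modulo with a clamp on |in1| to avoid dividing by a near-zero denominator, while fmod truncates toward zero, which differs for negative operands. A host-side C++ sketch of the floored form (plain C++, not the OpenCL kernel string):

    // Sketch only: floored modulo as written in the old MOD kernel body.
    #include <cmath>
    #include <cstdio>

    float floor_mod(float in0, float in1) {
        float denom = std::fabs(in1) > 1e-7f ? std::fabs(in1) : 1e-7f; // clamp tiny divisors
        float sign1 = (in1 >= 0.f) ? 1.f : -1.f;
        return in0 - std::floor(sign1 * in0 / denom) * in1;
    }

    int main() {
        // floored vs truncated modulo diverge once an operand is negative: 1.0 vs -1.0 here.
        printf("%f %f\n", floor_mod(-3.f, 2.f), std::fmod(-3.f, 2.f));
        return 0;
    }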
@ -354,12 +354,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
std::shared_ptr<Tensor> filterBuffer(
Tensor::createDevice<float>({mResource->mOutputChannel, ROUND_UP(mResource->mInputChannel, 4), mResource->mKernelWidth, mResource->mKernelHeight}));

int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
int buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
@ -367,25 +362,10 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
if(ptrCL != nullptr && res == CL_SUCCESS) {
::memset(ptrCL, 0, buffer_size);
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
for(int kh=0; kh<mResource->mKernelHeight; kh++) {
for(int kw=0; kw<mResource->mKernelWidth; kw++) {
int dst_idx = ((oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelHeight + kh)* mResource->mKernelWidth + kw;
int src_idx = ((oc * mResource->mInputChannel + ic) * mResource->mKernelHeight + kh)* mResource->mKernelWidth + kw;

((half_float::half*)ptrCL)[dst_idx] = (half_float::half)(mFilterDataPtr[src_idx]);
}
}
}
}
}else{
const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float);
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size);
}
const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float);
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size);
}
}
}else{
@ -397,10 +377,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), needTrans);
}
}
@ -697,8 +674,7 @@ ErrorCode ConvBufExecution::onExecute(const std::vector<Tensor *> &inputs, const
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"ConvBuf2D", event});
#else
if(mOpenCLBackend->isUseRecordQueue()){
if(mOpenCLBackend->isDevideOpRecord())
mOpenCLBackend->addRecord(mRecording);
mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
#ifdef LOG_VERBOSE
MNN_PRINT("End ConvExecution onExecute... \n");
#endif
@ -729,7 +705,7 @@ public:
#ifdef MNN_LOW_MEMORY
{
auto conv2dParams = op->main_as_Convolution2D();
if ((static_cast<OpenCLBackend *>(backend)->getMemory() == BackendConfig::Memory_Low) && (conv2dParams->quanParameter() != nullptr)) {
if (conv2dParams->quanParameter() != nullptr) {
if (((conv2dParams->quanParameter()->type() == 4) ||
(conv2dParams->quanParameter()->type() == 1) ||
(conv2dParams->quanParameter()->type() == 2))) {
@ -19,6 +19,7 @@ struct ConvBufResource {
const Convolution2DCommon *mConv2dCommonParams;
const Convolution2D *mConv2dParams;
std::shared_ptr<cl::Buffer> mKernelBuffer;
std::shared_ptr<cl::Image2D> mKernelImage;
std::shared_ptr<Tensor> dequantScale;
std::shared_ptr<Tensor> dequantOffset;
std::shared_ptr<Tensor> mFilter;
@ -33,6 +34,7 @@ struct ConvBufResource {
bool mConv1x1Opt = false;
bool mConv1x1C8Opt = false;
std::shared_ptr<Execution> mRasterExe;
bool mUseImage = false;
};

class ConvBufCommonExecution {
@ -13,7 +13,7 @@ namespace OpenCL {
// set mDequantScale mDequantOffset mNumQuantBit mFilterDataPtr from mConv2dParams
void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
quanCommon = ConvolutionCommon::load(mResource->mConv2dParams, this->backend(), false, true);
if ((mOpenCLBackend->getMemory() == BackendConfig::Memory_Low) && (mResource->mConv2dParams->quanParameter() != nullptr)) {
if (mResource->mConv2dParams->quanParameter() != nullptr) {
mLowMemoryFlag = true;
} else {
MNN_ERROR("Conv buf low memory init error.\n");
@ -30,44 +30,34 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<Convoluti
int numAlpha = mResource->mOutputChannel;
// set mDequantScale mDequantOffset
int numAlphaPack = ROUND_UP(numAlpha, 16);
mResource->dequantScale.reset(Tensor::createDevice<float>({numAlphaPack}));
mResource->dequantOffset.reset(Tensor::createDevice<float>({numAlphaPack}));
mResource->dequantScale.reset(Tensor::createDevice<int32_t>({numAlphaPack}));
mResource->dequantOffset.reset(Tensor::createDevice<int32_t>({numAlphaPack}));

mOpenCLBackend->onAcquireBuffer(mResource->dequantScale.get(), Backend::STATIC);
mOpenCLBackend->onAcquireBuffer(mResource->dequantOffset.get(), Backend::STATIC);
cl::Buffer &dequantScaleBuffer = openCLBuffer(mResource->dequantScale.get());
cl::Buffer &dequantOffsetBuffer = openCLBuffer(mResource->dequantOffset.get());
// transfer data from src in cpu to dst in gpu
int bytes = mOpenCLBackend->fpBytes();
int fpBytes = mOpenCLBackend->fpBytes();
cl_int resBias, resScale, resOffset;
void * dequantScaleBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantScaleBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * bytes, nullptr, nullptr, &resScale);
void * dequantOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantOffsetBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * bytes, nullptr, nullptr, &resOffset);

::memset(dequantScaleBufferMap, -1, numAlphaPack * bytes);
::memset(dequantOffsetBufferMap, 0, numAlphaPack * bytes);

void * dequantScaleBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantScaleBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * sizeof(int32_t), nullptr, nullptr, &resScale);
void * dequantOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantOffsetBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * sizeof(int32_t), nullptr, nullptr, &resOffset);

::memset(dequantScaleBufferMap, -1, numAlphaPack * sizeof(int32_t));
::memset(dequantOffsetBufferMap, 0, numAlphaPack * sizeof(int32_t));

if (dequantScaleBufferMap != nullptr && dequantOffsetBufferMap != nullptr && resScale == CL_SUCCESS && resOffset == CL_SUCCESS) {
if (bytes == 2) {
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((half_float::half *)dequantOffsetBufferMap)[i] = (half_float::half)dequantAlpha[2 * i];
((half_float::half *)dequantScaleBufferMap)[i] = (half_float::half)dequantAlpha[2 * i + 1];
}
} else {
for (int i = 0; i < numAlpha; ++i) {
((half_float::half *)dequantScaleBufferMap)[i] = (half_float::half)dequantAlpha[i];
((half_float::half *)dequantOffsetBufferMap)[i] = 0.0f;
}
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantOffsetBufferMap)[i] = dequantAlpha[2 * i];
((float *)dequantScaleBufferMap)[i] = dequantAlpha[2 * i + 1];
}
} else {
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantOffsetBufferMap)[i] = dequantAlpha[2 * i];
((float *)dequantScaleBufferMap)[i] = dequantAlpha[2 * i + 1];
}
} else {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantScaleBufferMap)[i] = dequantAlpha[i];
((float *)dequantOffsetBufferMap)[i] = 0.0f;
}
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantScaleBufferMap)[i] = dequantAlpha[i];
((float *)dequantOffsetBufferMap)[i] = 0.0f;
}
}
} else {
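In the rewritten block above, the quantization alpha array is interleaved as (offset, scale) pairs when the model is asymmetrically quantized and holds plain scales otherwise, and the mapped buffers are now always filled as 32-bit floats regardless of the backend's FP precision. A small host-side sketch of that unpacking, with illustrative names rather than the MNN API:

    // Sketch only: splitting the packed alpha array into scale/offset tables.
    #include <vector>

    void unpackAlpha(const float* alpha, int numAlpha, bool asymmetric,
                     std::vector<float>& scale, std::vector<float>& offset) {
        scale.resize(numAlpha);
        offset.resize(numAlpha);
        for (int i = 0; i < numAlpha; ++i) {
            if (asymmetric) {          // alpha holds (offset, scale) pairs
                offset[i] = alpha[2 * i];
                scale[i]  = alpha[2 * i + 1];
            } else {                   // alpha holds scales only
                scale[i]  = alpha[i];
                offset[i] = 0.0f;
            }
        }
    }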
@ -82,7 +72,7 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<Convoluti
// set mKernelBuffer for the 1x1 kernels
void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin, void * filterDataPtr, std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
cl_int res;
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({ROUND_UP(mResource->mOutputChannel, 8)/*Cout pack set to max 8*/, ROUND_UP(mResource->mInputChannel, packCin), mResource->mKernelWidth, mResource->mKernelHeight}));
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({ROUND_UP(mResource->mOutputChannel, packCout), ROUND_UP(mResource->mInputChannel, packCin), mResource->mKernelWidth, mResource->mKernelHeight}));
size_t buffer_size = filterBuffer->usize() / sizeof(float);
float *dequantAlpha = quanCommon->alpha.get();
// shared part for all cases
@ -93,43 +83,66 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin,
// int4 case
buffer_size /= 2;
} else {/* More types to be supported. */}
mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
auto kernelBufferPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
if(kernelBufferPtr != nullptr && res == CL_SUCCESS){
::memset(kernelBufferPtr, 0, buffer_size);

// Use Image load weights
void *mapPtr = nullptr;
size_t row_pitch;
size_t slice_pitch;
if(UP_DIV(mResource->mInputChannel, packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, packCout) <= 16384){
mResource->mUseImage = true;
}
if(mResource->mUseImage) {
if(mNumQuantBit == 4){
packCin *= 2;
}
size_t w = ROUND_UP(mResource->mOutputChannel, packCout);
size_t h = UP_DIV(mResource->mInputChannel, packCin);
mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_FLOAT), w, h, 0, nullptr, &res));
if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) {
MNN_ERROR("Alloc Image %d x %d error, code:%d \n", w, h, res);
}
mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapImage(*(mResource->mKernelImage.get()), true, CL_MAP_WRITE, {0, 0, 0}, {w, h, 1}, &row_pitch, &slice_pitch, nullptr, nullptr, &res);
if(mNumQuantBit == 4){
row_pitch *= 2;
}
} else{
mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
row_pitch = ROUND_UP(mResource->mOutputChannel, packCout) * packCin;
}
if(mapPtr != nullptr && res == CL_SUCCESS){
for(int o = 0; o < mResource->mOutputChannel; o++){
float zero = 0;
if(quanCommon->asymmetric){
zero = (-dequantAlpha[2 * o + 1])/dequantAlpha[2 * o];
}
int i = 0;
for(; i < mResource->mInputChannel; i++){
int bufferIdx = (o/packCout) * packCin*packCout + (i/packCin)*packCin*ROUND_UP(mResource->mOutputChannel, packCout) + (o%packCout)*packCin + (i%packCin);//(Ci/packCin, Co/packCout,packCout, packCin)
for(; i < mResource->mInputChannel ; i++){
int bufferIdx = (i/packCin) * row_pitch + o*packCin + (i%packCin);//(Ci/packCin, Co/packCout * packCout * packCin)
int filterIdx = o*mResource->mInputChannel + i;
if (mNumQuantBit == 8) {
// int8 case
((int8_t *)kernelBufferPtr)[bufferIdx] = (int8_t)(((int8_t *)filterDataPtr)[filterIdx]);
((int8_t *)mapPtr)[bufferIdx] = (int8_t)(((int8_t *)filterDataPtr)[filterIdx]);
} else if (mNumQuantBit == 4){
// int4 case
if (bufferIdx % 2 == 0) {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)((((int8_t *)filterDataPtr)[filterIdx] + 8) * 16);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)((((int8_t *)filterDataPtr)[filterIdx] + 8) * 16);
} else {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)(((int8_t *)filterDataPtr)[filterIdx] + 8);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)(((int8_t *)filterDataPtr)[filterIdx] + 8);
}
} else {/* More types to be supported. */}
}
for(; i < ROUND_UP(mResource->mInputChannel, 4); i++){
int bufferIdx = (o/packCout) * packCin*packCout + (i/packCin)*packCin*ROUND_UP(mResource->mOutputChannel, packCout) + (i%packCin)*packCout + (o%packCout);//(Ci/packCin, Co/packCout, packCin, packCout)
for(; i < ROUND_UP(mResource->mInputChannel, packCin); i++){
int bufferIdx = (i/packCin) * row_pitch + o*packCin + (i%packCin);//(Ci/packCin, Co/packCout * packCout * packCin)
if (mNumQuantBit == 8) {
// int8 case
((int8_t *)kernelBufferPtr)[bufferIdx] = (int8_t)(zero);
((int8_t *)mapPtr)[bufferIdx] = (int8_t)(zero);
} else if (mNumQuantBit == 4){
// int4 case
if (bufferIdx % 2 == 0) {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)((zero + 8) * 16);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)((zero + 8) * 16);
} else {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)(zero + 8);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)(zero + 8);
}
}
}
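The int4 branch above stores two signed 4-bit weights per byte: each value is biased by +8 so it fits in [0, 15], and the weight at an even packed index lands in the high nibble. A standalone sketch of that packing rule (plain C++, not MNN code):

    // Sketch only: int4 weight packing with +8 bias, even index in the high nibble.
    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> packInt4(const std::vector<int8_t>& w) { // each w[i] in [-8, 7]
        std::vector<uint8_t> out((w.size() + 1) / 2, 0);
        for (size_t i = 0; i < w.size(); ++i) {
            uint8_t v = static_cast<uint8_t>(w[i] + 8);
            if (i % 2 == 0) {
                out[i / 2] += static_cast<uint8_t>(v * 16); // high nibble
            } else {
                out[i / 2] += v;                            // low nibble
            }
        }
        return out;
    }

Tail positions beyond the real input-channel count are filled with the per-channel zero point in the same format, so the padded lanes dequantize to zero.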
@ -138,7 +151,11 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin,
MNN_ERROR("set1x1WeightLowMemory: Map error ptrCL == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelBuffer.get()), kernelBufferPtr);
if(mResource->mUseImage){
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelImage.get()), mapPtr);
} else{
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelBuffer.get()), mapPtr);
}
}
// set mFilter for the general kernels
void ConvBufLowMemoryExecution::setGeneralWeightLowMemory(void* filterDataPtr, std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
@ -321,13 +338,11 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
const int width = outputShape.at(2);
const int inputChannelBlocks = UP_DIV(inputChannels, 4);
const int outputChannelBlocks = UP_DIV(outChannel, 4);
std::string kernelname = "gemm_conv_buf";
int global_x = outputChannelBlocks;

int global_y = batch * height;

const int total_kernel = 3;
std::string kernelName[total_kernel] = {"gemm_conv_c1_buf", "gemm_conv_c2_buf", "gemm_conv_c4_buf",};
int itemC[total_kernel] = {1, 2, 4};
const int total_kernel = 5;
std::string kernelName[total_kernel] = {"gemm_conv_c1_buf", "gemm_conv_c2_buf", "gemm_conv_c4_buf", "gemm_conv_c1_image", "gemm_conv_c2_image"};
int itemC[total_kernel] = {1, 2, 4, 1, 2};
int actual_kernel = total_kernel;
std::shared_ptr<KernelWrap> kernel[total_kernel];
std::vector<uint32_t> globalWorkSize[total_kernel];
@ -337,12 +352,27 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
if(width == 1 && height == 1){
buildOption.emplace("-DWIDTH_HEIGHT_1");
}

if(inputChannels % 16 != 0){
buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
} else if (mResource->mUseImage && mNumQuantBit == 4 && inputChannels % 32 != 0) {
// Image weight-int4 use load32
buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
}
std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel);
for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", kernelName[knl_idx], buildOption);
if(batch > 1){
global_y = UP_DIV(batch, 2) * height;
buildOption.emplace("-DBACTH_BLOCK2");
info += "_BATCH_BLOCK2";
}
int knl_idx = 0;
actual_kernel = 3;
if(mResource->mUseImage){
knl_idx = 3;
actual_kernel = total_kernel;
}
for (; knl_idx < actual_kernel; knl_idx++) {
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[knl_idx], buildOption);
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));

globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outChannel, itemC[knl_idx]) * width), static_cast<uint32_t>(global_y)};
@ -351,7 +381,11 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]);
ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]);
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input));
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelBuffer.get());
if(mResource->mUseImage){
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantScale.get()));
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantOffset.get()));
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
@ -361,7 +395,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(batch));
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(height));
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(width));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv_buf Kernel Select");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select");
std::pair<std::vector<uint32_t>, int> retTune;
retTune = gws2dLwsTune(kernel[knl_idx], globalWorkSize[knl_idx], kernelName[knl_idx] + info, maxWorkGroupSize);
if(min_cost.first > retTune.second) {
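The loop above times each candidate gemv kernel through gws2dLwsTune and keeps whichever candidate reports the lowest cost, then rebuilds only that kernel for the final unit. A generic sketch of that select-by-minimum-cost pattern, using a hypothetical timing callback rather than the MNN tuning API:

    // Sketch only: pick the cheapest kernel candidate by measured cost.
    #include <cstdint>
    #include <limits>
    #include <vector>

    int pickBestKernel(const std::vector<int>& candidateIds,
                       uint32_t (*measureCostUs)(int id)) { // stand-in for the tuner
        uint32_t minCost = std::numeric_limits<uint32_t>::max();
        int best = -1;
        for (int id : candidateIds) {
            uint32_t cost = measureCostUs(id); // e.g. timed launch with tuned local size
            if (cost < minCost) {
                minCost = cost;
                best = id;
            }
        }
        return best;
    }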
@ -374,14 +408,18 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};


unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", kernelName[min_index], buildOption);
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[min_index], buildOption);
//MNN_PRINT("Kernel is %d.\n", min_index);
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
if(mResource->mUseImage){
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScale.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantOffset.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
@ -391,7 +429,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(batch));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(height));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(width));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv_buf");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf");
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]};
@ -422,7 +460,7 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector<Tensor *>
// prepare mDequantScale mDequantOffset mFilterDataPtr
getInfoFromOpLowMemory(quanCommon);
//select opt conv method
if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1) {
if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1 && conv2dCommonParams->padX() == 0 && conv2dCommonParams->padY() == 0 && conv2dCommonParams->dilateX() == 1 && conv2dCommonParams->dilateY() == 1) {
set1x1WeightLowMemory(4, 16, mFilterDataPtr, quanCommon);
mResource->mConv1x1Opt = true;
}else {
@ -48,24 +48,13 @@ DeconvBufExecution::DeconvBufExecution(const std::vector<Tensor *> &inputs, cons
std::shared_ptr<Tensor> filterBuffer(
Tensor::createDevice<float>({outputChannel, inputChannel, kernelHeight, kernelWidth}));

int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
size_t buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
cl_int error;
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
if(ptrCL != nullptr && error == CL_SUCCESS){
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for(int i=0; i<filterBuffer->elementSize(); i++) {
((half_float::half*)ptrCL)[i] = (half_float::half)(filterDataPtrTransformed[i]);
}
}else{
::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size());
}
::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size());
}else{
MNN_ERROR("Map error ptrCL == nullptr \n");
}
@ -75,10 +64,7 @@ DeconvBufExecution::DeconvBufExecution(const std::vector<Tensor *> &inputs, cons
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), needTrans);
mResource->mBuildOptions.emplace("-DBIAS");
if (conv2dCommonParams->relu() == true) {
@ -39,24 +39,13 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector<Tensor *>
mResource->mFilter.reset(Tensor::createDevice<float>({1, ROUND_UP(filterImageShape[1], 2)/*for kernel C8 read*/, 1, 4 * filterImageShape[0]}));
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>(filterShape));

int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
size_t buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
cl_int error;
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
if(ptrCL != nullptr && error == CL_SUCCESS){
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for (int i = 0; i < filterBuffer->elementSize(); i++) {
((half_float::half *)ptrCL)[i] = (half_float::half)(filterDataPtr[i]);
}
} else {
::memcpy(ptrCL, filterDataPtr, filterBuffer->size());
}
::memcpy(ptrCL, filterDataPtr, filterBuffer->size());
}else{
MNN_ERROR("Map error ptrCL == nullptr \n");
}
@ -65,10 +54,7 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector<Tensor *>
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::DW_CONV2D_FILTER, mResource->mFilter.get(), needTrans);

if (mResource->mConv2dCommonParams->relu() == true) {
@ -94,7 +94,7 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
Tensor *input = inputs[0];
Tensor *output = outputs[0];
auto runtime = ((OpenCLBackend *)backend())->getOpenCLRuntime();
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize) / 4;

std::vector<int> inputShape = tensorShapeFormat(input);
std::vector<int> outputShape = tensorShapeFormat(output);
@ -114,6 +114,14 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
inner_size *= inputs.at(0)->length(i);
}

if (group_ > 1) {
outter_size = inputs[0]->length(0) * group_;
inner_size = 1;
for (int i = 1; i < rank; i++) {
inner_size *= inputs[0]->length(i);
}
inner_size /= group_;
}

std::set<std::string> buildOptions;
if(RMSNorm){
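With group_ > 1 the block above folds the group dimension into the outer size, so every normalization row covers C*H*W/group elements of one sample. A quick worked example of that partitioning with assumed shapes (the numbers are illustrative, not taken from the diff):

    // Worked example of the group-LayerNorm partitioning above (assumed NCHW shape).
    #include <cstdio>

    int main() {
        int batch = 2, channel = 32, height = 8, width = 8, group = 4;
        int outter_size = batch * group;                       // 8 rows to normalize
        int inner_size  = (channel * height * width) / group;  // 512 elements per row
        printf("outter=%d inner=%d\n", outter_size, inner_size);
        return 0;
    }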
@ -150,6 +158,111 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
mGWS = {static_cast<uint32_t>(local_size),
static_cast<uint32_t>(1),
static_cast<uint32_t>(inputBatch)};
} else if(inner_size == inputWidth * inputHeight * inputChannels / group_ && outter_size == inputBatch * group_){
mUnits.clear();
mUnits.resize(3);
std::vector<int> inputShape = tensorShapeFormat(inputs[0]);
int inputWH[] = {inputShape[2], inputShape[1]};
int region[] = {inputShape[0], UP_DIV(inputShape[3], 4), inputShape[1], inputShape[2]};

mInputPlain = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE));
mOpenCLBackend->onAcquireBuffer(mInputPlain.get(), Backend::DYNAMIC);
mOutputPlain = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE));
mOpenCLBackend->onAcquireBuffer(mOutputPlain.get(), Backend::DYNAMIC);

// convert nc4hw4 to nchw
{
auto &unit = mUnits[0];
unit.kernel = runtime->buildKernel("buffer_convert_buf", "nc4hw4_buffer_to_nchw_buffer", {}, inputs[0], outputs[0]);

mGWS = {(uint32_t)(UP_DIV(region[3] * region[1], 16) * 16),
(uint32_t)(UP_DIV(region[2] * region[0], 16) * 16)};
mLWS = {16, 16};
unit.globalWorkSize = {mGWS[0], mGWS[1]};
unit.localWorkSize = {mLWS[0], mLWS[1]};

int global_dim0 = region[3] * region[1];
int global_dim1 = region[2] * region[0];

//MNN_CHECK_CL_SUCCESS
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, global_dim0);
ret |= unit.kernel->get().setArg(idx++, global_dim1);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, inputWH[1]);
ret |= unit.kernel->get().setArg(idx++, inputWH[0]);
ret |= unit.kernel->get().setArg(idx++, inputShape[3]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, convert nc4hw4 to nchw");

mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS);
}
// do group layernorm
{
auto &unit = mUnits[1];
kernelName = "layernorm_plain_buf";
local_size = getLocalSize(UP_DIV(inner_size, 4), MaxLocalSize);
buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size));
unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions);

mGWS = {static_cast<uint32_t>(local_size),
static_cast<uint32_t>(1),
static_cast<uint32_t>(outter_size)};

mLWS = {static_cast<uint32_t>(local_size), 1, 1};

unit.globalWorkSize = {mGWS[0], mGWS[1], mGWS[2]};
unit.localWorkSize = {mLWS[0], mLWS[1], mLWS[2]};

uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGWS[0]);
ret |= unit.kernel->get().setArg(idx++, mGWS[1]);
ret |= unit.kernel->get().setArg(idx++, mGWS[2]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, static_cast<int32_t>(inner_size));
ret |= unit.kernel->get().setArg(idx++, static_cast<int32_t>(outter_size));
if(has_gamma_beta_){
ret |= unit.kernel->get().setArg(idx++, *mGammaBuffer.get());
ret |= unit.kernel->get().setArg(idx++, *mBetaBuffer.get());
}
ret |= unit.kernel->get().setArg(idx++, epsilon_);
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, do group layernorm");
mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS);
}
// convert nchw to nc4hw4
{
auto &unit = mUnits[2];

unit.kernel = runtime->buildKernel("buffer_convert_buf", "nchw_buffer_to_nc4hw4_buffer", {}, inputs[0], outputs[0]);
mLWS = {16, 16};
mGWS = {(uint32_t)UP_DIV(region[3] * region[1], 16) * 16,
(uint32_t)UP_DIV(region[2] * region[0], 16) * 16};

unit.globalWorkSize = {mGWS[0], mGWS[1]};
unit.localWorkSize = {mLWS[0], mLWS[1]};

int global_dim0 = region[3] * region[1];
int global_dim1 = region[2] * region[0];

uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, global_dim0);
ret |= unit.kernel->get().setArg(idx++, global_dim1);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, inputWH[1]);
ret |= unit.kernel->get().setArg(idx++, inputWH[0]);
ret |= unit.kernel->get().setArg(idx++, inputShape[3]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output));
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, convert nchw to nc4hw4");
mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS);
}

mOpenCLBackend->onReleaseBuffer(mInputPlain.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mOutputPlain.get(), Backend::DYNAMIC);
return NO_ERROR;
}
mLWS = {static_cast<uint32_t>(local_size), 1, 1};
@ -189,10 +302,6 @@ public:
TensorUtils::setTensorSupportPack(outputs[i], false);
}
const auto* layer_norm_param = op->main_as_LayerNorm();
int group = layer_norm_param->group();
if(group > 1){
return nullptr;
}
return new LayerNormBufExecution(inputs, op, backend);
}
};
@ -35,6 +35,7 @@ private:

std::shared_ptr<cl::Buffer> mGammaBuffer;
std::shared_ptr<cl::Buffer> mBetaBuffer;
std::shared_ptr<Tensor> mInputPlain, mOutputPlain;
bool has_gamma_beta_ = false;
uint32_t mMaxWorkGroupSize;
};
@ -208,7 +208,8 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector<Tensor *> &inp
const int Width = Shape.at(2);
const int Height = Shape.at(1);
const int Batch = Shape.at(0);
mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, ROUND_UP(Height, 4), ROUND_UP(Width, 4)}, Tensor::CAFFE));
mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE));

mOpenCLBackend->onAcquireBuffer(mTmpTensors[i].get(), Backend::DYNAMIC);

Unit unit;
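The pair of lines above shows the old and new sizing of the temporary batch-matmul tensor: the ROUND_UP padding on height and width is dropped, so the dynamic buffer is allocated at the tight size. A rough illustration of the difference with assumed dimensions (not values from the diff):

    // Rough illustration of the allocation change above (dimensions are assumptions).
    #include <cstdio>

    static int round_up(int a, int b) { return ((a + b - 1) / b) * b; }

    int main() {
        int batch = 4, channel = 1, height = 37, width = 53;
        long padded = 1L * batch * channel * round_up(height, 4) * round_up(width, 4); // old sizing
        long tight  = 1L * batch * channel * height * width;                           // new sizing
        printf("padded=%ld elems, tight=%ld elems\n", padded, tight); // 8960 vs 7844
        return 0;
    }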
@ -54,7 +54,9 @@ ErrorCode SoftmaxBufExecution::onEncode(const std::vector<Tensor *> &inputs, con

const auto dims = input->buffer().dimensions;
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);

auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize) / 4;

int inside = 1;
int outside = 1;
int channel = 1;