MNN:Sync Sync Internal 2.9.0

xiaying 2024-05-11 19:17:02 +08:00
parent ba4ecd9792
commit 7cad2ee83f
236 changed files with 35800 additions and 5499 deletions

View File

@ -58,9 +58,18 @@ option(MNN_ENABLE_COVERAGE "Build with coverage enable" OFF)
option(MNN_BUILD_PROTOBUFFER "Build with protobuffer in MNN" ON)
option(MNN_BUILD_OPENCV "Build OpenCV api in MNN." OFF)
option(MNN_BUILD_LLM "Build llm library based MNN." OFF)
option(MNN_BUILD_DIFFUSION "Build diffusion demo based MNN." OFF)
option(MNN_INTERNAL "Build with MNN internal features, such as model authentication, metrics logging" OFF)
option(MNN_JNI "Build MNN Jni for java to use" OFF)
IF (OHOS)
include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
export_headers(DIR ${CMAKE_SOURCE_DIR}/include/MNN)
IF (MNN_BUILD_OPENCV)
export_headers(DIR ${CMAKE_SOURCE_DIR}/tools/cv/include/cv)
ENDIF()
ENDIF()
IF (NOT DEFINED MNN_USE_SPARSE_COMPUTE)
set(MNN_USE_SPARSE_COMPUTE ON)
ENDIF()
@ -263,12 +272,12 @@ endif()
option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
if (NOT MSVC)
if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
elseif(MNN_USE_CPP11)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x")
@ -493,7 +502,7 @@ endif()
IF(MNN_COREML)
add_definitions(-DMNN_COREML_ENABLED=1)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/source/backend/coreml/)
IF(MNN_SEP_BUILD)
list(APPEND MNN_DEPS MNNCoreML)
list(APPEND MNN_EXTRA_DEPENDS MNNCoreML)
@ -631,7 +640,7 @@ ENDIF()
IF(MNN_BUILD_LLM)
# add_definitions(-DMNN_BUILD_LLM)
include(${CMAKE_CURRENT_LIST_DIR}/llm/CMakeLists.txt)
include(${CMAKE_CURRENT_LIST_DIR}/transformers/llm/engine/CMakeLists.txt)
ENDIF()
# NPU
@ -755,6 +764,9 @@ list(REMOVE_ITEM MNN_TARGETS MNN)
IF(MNN_BUILD_DEMO)
include(${CMAKE_CURRENT_LIST_DIR}/demo/exec/CMakeLists.txt)
ENDIF()
IF(MNN_BUILD_DIFFUSION AND MNN_BUILD_OPENCV AND MNN_IMGCODECS)
include(${CMAKE_CURRENT_LIST_DIR}/transformers/diffusion/CMakeLists.txt)
ENDIF()
IF(MNN_BUILD_TOOLS)
include(${CMAKE_CURRENT_LIST_DIR}/tools/cpp/CMakeLists.txt)
ENDIF()

View File

@ -127,7 +127,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
auto modelBuffer = revertor->getBuffer();
const auto bufferSize = revertor->getBufferSize();
auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize));
auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy);
revertor.reset();
net->setSessionMode(MNN::Interpreter::Session_Release);
MNN::ScheduleConfig config;
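
Note: the change above pairs Interpreter::createFromBuffer with Interpreter::destroy as the shared_ptr deleter. A minimal sketch of that ownership pattern, assuming only the public Interpreter API (buffer and size supplied by the caller):

#include <MNN/Interpreter.hpp>
#include <memory>

// Interpreters created by createFromBuffer must be released via Interpreter::destroy,
// not plain delete, so the deleter is passed to shared_ptr explicitly.
std::shared_ptr<MNN::Interpreter> makeNet(const void* buffer, size_t size) {
    return std::shared_ptr<MNN::Interpreter>(
        MNN::Interpreter::createFromBuffer(buffer, size),
        MNN::Interpreter::destroy);
}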

View File

@ -80,6 +80,7 @@ MNN uses CMake to build the project; the CMake macro options are listed below
| MNN_OPENCV_BENCH | Whether to enable the performance benchmark for MNN's OpenCV module, default `OFF` |
| MNN_VULKAN_IMAGE | Use the Image memory mode when building MNN's Vulkan backend, to support FP16 and GPU acceleration on some mobile devices, default `ON` |
| MNN_LOW_MEMORY | Whether to support low-memory mode; when enabled and a weight-quantized model is run with `low_memory` set, weights are dequantized at compute time, default `OFF` |
| MNN_SUPPORT_RENDER | Whether to support graphics-rendering-related operators, default `OFF` |
| MNN_SUPPORT_TRANSFORMER_FUSE | Whether to support fused Transformer-related ops, default `OFF` |
| MNN_BUILD_LLM | Whether to build the MNN-based llm library and demos, default `OFF` |
| MNN_BUILD_DIFFUSION | Whether to build the MNN-based diffusion demo; requires MNN_BUILD_OPENCV and MNN_IMGCODECS to be enabled, default `OFF` |

View File

@ -60,6 +60,7 @@
- `winogradExample.out` winograd example
- `fuseTest` tests GPU custom operators; currently only the Vulkan Buffer mode is supported
- `GpuInterTest.out` tests GPU memory input; currently only the OpenCL Buffer mode and OpenGL texture mode are supported, and MNN_OPENCL and MNN_OPENGL must be enabled at build time
- `LoRA` merges LoRA weights into the model weights
## Benchmark Tool
- Related build options
- `MNN_BUILD_BENCHMARK` whether to build the Benchmark tool

View File

@ -27,6 +27,7 @@ namespace Express {
void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread) {
std::lock_guard<std::mutex> _l(mMutex);
if(type == MNN_FORWARD_AUTO) {
ScheduleConfig sConfig;
sConfig.type = type;
@ -41,10 +42,12 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig&
info.numThread = 4;
}
mAttr->firstType = std::make_pair(type, info.numThread);
info.user = (BackendConfig*)&config;
std::shared_ptr<Runtime> bn(creator->onCreate(info));
mRuntimes[mAttr->firstType] = bn;
auto firstIter = mRuntimes.find(mAttr->firstType);
if (firstIter == mRuntimes.end()) {
info.user = (BackendConfig*)&config;
std::shared_ptr<Runtime> bn(creator->onCreate(info));
mRuntimes[mAttr->firstType] = bn;
}
} else {
auto creator = MNNGetExtraRuntimeCreator(type);
if (nullptr == creator) {
@ -56,11 +59,14 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig&
Backend::Info info;
info.type = type;
mAttr->firstType = std::make_pair(type, numberThread);
info.mode = Backend::Info::DIRECT;
info.numThread = numberThread;
info.user = (BackendConfig*)&config;
std::shared_ptr<Runtime> bn(creator->onCreate(info));
mRuntimes[mAttr->firstType] = bn;
auto firstIter = mRuntimes.find(mAttr->firstType);
if (firstIter == mRuntimes.end()) {
info.mode = Backend::Info::DIRECT;
info.numThread = numberThread;
info.user = (BackendConfig*)&config;
std::shared_ptr<Runtime> bn(creator->onCreate(info));
mRuntimes[mAttr->firstType] = bn;
}
}
_refreshRuntime();
}
@ -155,6 +161,10 @@ std::shared_ptr<Executor> Executor::newExecutor(MNNForwardType type,
const BackendConfig& config,
int numberThread) {
auto creator = MNNGetExtraRuntimeCreator(type);
if(nullptr == creator) {
MNN_ERROR("Don't support %d\n", type);
return nullptr;
}
Backend::Info info;
info.type = type;
info.numThread = numberThread;

View File

@ -98,6 +98,21 @@ static std::vector<std::shared_ptr<BufferStorage>> preRearrangeWeights( // NOLIN
}
break;
}
case MNN::OpType_Attention: {
exe.reset(backend->onCreate({}, {}, op));
if (exe.get() == nullptr) {
exe.reset(backupBackend->onCreate({}, {}, op));
}
if (nullptr == exe) {
break;
}
// Drop the exe if it can't be cloned
if (!exe->onClone(nullptr, op, nullptr)) {
exe = nullptr;
break;
}
break;
}
default: {
break;
}

View File

@ -68,7 +68,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR_IMP(x) #x
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 8
#define MNN_VERSION_PATCH 4
#define MNN_VERSION_MINOR 9
#define MNN_VERSION_PATCH 0
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
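
With this bump, MNN_VERSION stringifies to "2.9.0" through the nested STR/STR_IMP macros above. A quick check, assuming <MNN/MNNDefine.h> is on the include path:

#include <MNN/MNNDefine.h>
#include <cstdio>

int main() {
    // MNN_VERSION expands to "2" "." "9" "." "0", i.e. the string literal "2.9.0"
    printf("MNN version: %s\n", MNN_VERSION);
    return 0;
}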

16
project/harmony/build_64.sh Executable file
View File

@ -0,0 +1,16 @@
#!/bin/bash
cmake ../../../ \
-DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake \
-DCMAKE_BUILD_TYPE=Release \
-DOHOS_ARCH="arm64-v8a" \
-DOHOS_STL=c++_static \
-DMNN_USE_LOGCAT=false \
-DMNN_BUILD_BENCHMARK=ON \
-DMNN_USE_SSE=OFF \
-DMNN_SUPPORT_BF16=OFF \
-DMNN_BUILD_TEST=ON \
-DOHOS_PLATFORM_LEVEL=9 \
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3
make -j4

16
project/harmony/updateTest.sh Executable file
View File

@ -0,0 +1,16 @@
#!/bin/bash
DIR=yanxing
make -j16
hdc file send ./libMNN.so /data/local/tmp/$DIR/libMNN.so
hdc file send ./libMNN_Express.so /data/local/tmp/$DIR/libMNN_Express.so
hdc file send ./MNNV2Basic.out /data/local/tmp/$DIR/MNNV2Basic.out
hdc file send ./ModuleBasic.out /data/local/tmp/$DIR/ModuleBasic.out
# hdc shell "cd /data/local/tmp/$DIR && rm -r output"
# hdc shell "cd /data/local/tmp/$DIR && mkdir output"
hdc file send ./unitTest.out /data/local/tmp/$DIR/unitTest.out
hdc file send ./testModel.out /data/local/tmp/$DIR/testModel.out
hdc file send ./testModelWithDescribe.out /data/local/tmp/$DIR/testModelWithDescribe.out
hdc file send ./backendTest.out /data/local/tmp/$DIR/backendTest.out
hdc file send ./timeProfile.out /data/local/tmp/$DIR/timeProfile.out
hdc file send ./run_test.out /data/local/tmp/$DIR/run_test.out

View File

@ -174,7 +174,6 @@
489D7A7B2550FDC800AD896A /* MetalUnary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */; };
489D7A7D2550FDC900AD896A /* MetalConvolution.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A2C2550FDC800AD896A /* MetalConvolution.mm */; };
489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */; };
489D7A7F2550FDC900AD896A /* MetalReLU.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2E2550FDC800AD896A /* MetalReLU.hpp */; };
489D7A802550FDC900AD896A /* MetalEltwise.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */; };
489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A302550FDC800AD896A /* MetalPooling.hpp */; };
489D7A822550FDC900AD896A /* MetalPReLU.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A312550FDC800AD896A /* MetalPReLU.hpp */; };
@ -184,7 +183,6 @@
489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */; };
489D7A8B2550FDC900AD896A /* MetalConvolutionWinograd.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */; };
489D7A8C2550FDC900AD896A /* MetalDeconvolution.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */; };
489D7A8D2550FDC900AD896A /* MetalReLU.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3C2550FDC800AD896A /* MetalReLU.mm */; };
489D7A8E2550FDC900AD896A /* MetalPooling.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A3D2550FDC800AD896A /* MetalPooling.mm */; };
489D7A902550FDC900AD896A /* MetalConvolution.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */; };
489D7A912550FDC900AD896A /* MetalScale.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A402550FDC800AD896A /* MetalScale.mm */; };
@ -746,6 +744,7 @@
9558333D29B0947300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558333C29B0947300488807 /* MNNGelu.S */; };
9558334729B09A2300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334629B09A2300488807 /* MNNGelu.S */; };
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334A29B09A7B00488807 /* MNNGeluFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */; };
956F52E12AB2D692004B13D9 /* ImageProcessUtils.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */; };
956F52E32AB2D6A1004B13D9 /* ImageProcessUtils.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */; };
958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; };
@ -1005,7 +1004,6 @@
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalUnary.hpp; sourceTree = "<group>"; };
489D7A2C2550FDC800AD896A /* MetalConvolution.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolution.mm; sourceTree = "<group>"; };
489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MNNMetalContext.mm; sourceTree = "<group>"; };
489D7A2E2550FDC800AD896A /* MetalReLU.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReLU.hpp; sourceTree = "<group>"; };
489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalEltwise.hpp; sourceTree = "<group>"; };
489D7A302550FDC800AD896A /* MetalPooling.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalPooling.hpp; sourceTree = "<group>"; };
489D7A312550FDC800AD896A /* MetalPReLU.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalPReLU.hpp; sourceTree = "<group>"; };
@ -1015,7 +1013,6 @@
489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolutionDepthwise.mm; sourceTree = "<group>"; };
489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalConvolutionWinograd.hpp; sourceTree = "<group>"; };
489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalDeconvolution.mm; sourceTree = "<group>"; };
489D7A3C2550FDC800AD896A /* MetalReLU.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReLU.mm; sourceTree = "<group>"; };
489D7A3D2550FDC800AD896A /* MetalPooling.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalPooling.mm; sourceTree = "<group>"; };
489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalConvolution.hpp; sourceTree = "<group>"; };
489D7A402550FDC800AD896A /* MetalScale.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalScale.mm; sourceTree = "<group>"; };
@ -1587,6 +1584,7 @@
9558333C29B0947300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
9558334629B09A2300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
9558334A29B09A7B00488807 /* MNNGeluFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGeluFP16.S; path = ../../../arm82/asm/arm64/MNNGeluFP16.S; sourceTree = "<group>"; };
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryLayernorm.cpp; sourceTree = "<group>"; };
956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ImageProcessUtils.cpp; sourceTree = "<group>"; };
956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ImageProcessUtils.hpp; sourceTree = "<group>"; };
958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = "<group>"; };
@ -1842,6 +1840,7 @@
4819FB2624C139690050BD09 /* GeometryLSTM.cpp */,
4819FB2424C139680050BD09 /* GeometryPoolGrad.cpp */,
4819FB2A24C139690050BD09 /* GeometryReduce.cpp */,
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */,
489404DD24A2FC2B001E456C /* GeometryReverseSequence.cpp */,
48FD0349246AA40300456AF5 /* GeometryConvert.cpp */,
48FD12BD2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp */,
@ -2150,7 +2149,6 @@
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */,
489D7A2C2550FDC800AD896A /* MetalConvolution.mm */,
489D7A2D2550FDC800AD896A /* MNNMetalContext.mm */,
489D7A2E2550FDC800AD896A /* MetalReLU.hpp */,
489D7A2F2550FDC800AD896A /* MetalEltwise.hpp */,
489D7A302550FDC800AD896A /* MetalPooling.hpp */,
489D7A312550FDC800AD896A /* MetalPReLU.hpp */,
@ -2160,7 +2158,6 @@
489D7A392550FDC800AD896A /* MetalConvolutionDepthwise.mm */,
489D7A3A2550FDC800AD896A /* MetalConvolutionWinograd.hpp */,
489D7A3B2550FDC800AD896A /* MetalDeconvolution.mm */,
489D7A3C2550FDC800AD896A /* MetalReLU.mm */,
489D7A3D2550FDC800AD896A /* MetalPooling.mm */,
489D7A3F2550FDC800AD896A /* MetalConvolution.hpp */,
489D7A402550FDC800AD896A /* MetalScale.mm */,
@ -3070,7 +3067,6 @@
EBECA39924643D320062C7A3 /* Arm82Relu.hpp in Headers */,
4838EA7C2611BFE20027232C /* CPUGridSample.hpp in Headers */,
92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */,
489D7A7F2550FDC900AD896A /* MetalReLU.hpp in Headers */,
92FF03D123AA0B5A00AC97F6 /* CPUTopKV2.hpp in Headers */,
92FF033F23AA0B5A00AC97F6 /* CPUArgMax.hpp in Headers */,
92FF034C23AA0B5A00AC97F6 /* CPUSetDiff1D.hpp in Headers */,
@ -3602,7 +3598,6 @@
486E1A9924F5078D00C16006 /* CPURandomUniform.cpp in Sources */,
92FF02C823AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
92FF045C23AA0B7100AC97F6 /* ShapeBroadcastTo.cpp in Sources */,
489D7A8D2550FDC900AD896A /* MetalReLU.mm in Sources */,
48747D49245D9D24000B9709 /* RuntimeFactory.cpp in Sources */,
92FF02AE23AA0B5A00AC97F6 /* CPUProposal.cpp in Sources */,
92FF042723AA0B7100AC97F6 /* ShapeMatMul.cpp in Sources */,
@ -3627,6 +3622,7 @@
92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */,
92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */,
CECF8C87299CAD9400D3875B /* sds.c in Sources */,
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */,
4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */,
CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */,
92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */,
@ -4131,7 +4127,7 @@
CODE_SIGN_STYLE = Automatic;
DEAD_CODE_STRIPPING = YES;
DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
@ -4218,7 +4214,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
GCC_ENABLE_CPP_EXCEPTIONS = NO;
GCC_ENABLE_CPP_RTTI = NO;
HEADER_SEARCH_PATHS = (

View File

@ -17,6 +17,10 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
if (OHOS)
include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake)
endif()
if(PYMNN_INTERNAL_SERVING)
file(GLOB_RECURSE SRC ${CMAKE_CURRENT_LIST_DIR}/src/MNN.cc
${CMAKE_CURRENT_LIST_DIR}/src/internal/monitor_service.cc
@ -185,12 +189,20 @@ if(WIN32 OR APPLE OR CMAKE_SYSTEM_NAME MATCHES "^Linux")
else()
target_include_directories(mnnpybridge PRIVATE ${MNN_DIR}/pymnn/src ${MNN_DIR}/pymnn/android/src/main/c/include)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${MNN_DIR}/pymnn/android/src/main/jniLibs/${ANDROID_ABI})
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
if(PYMNN_USE_ALINNPYTHON)
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
endif()
if(PYMNN_NUMPY_USABLE)
target_link_libraries(mnnpybridge PRIVATE numpy_python)
if (OHOS)
target_link_libraries(mnnpybridge PRIVATE tcpkg::mnn)
if(PYMNN_USE_ALINNPYTHON)
target_link_libraries(mnnpybridge PRIVATE tcpkg::alinnpython)
endif()
export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN)
else()
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
if(PYMNN_USE_ALINNPYTHON)
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
endif()
if(PYMNN_NUMPY_USABLE)
target_link_libraries(mnnpybridge PRIVATE numpy_python)
endif()
endif()
endif()

View File

@ -24,6 +24,9 @@ struct ExtraT;
struct StringVec;
struct StringVecT;
struct AttentionParam;
struct AttentionParamT;
struct FmhaV2Param;
struct FmhaV2ParamT;
@ -69,6 +72,8 @@ inline const flatbuffers::TypeTable *ExtraTypeTable();
inline const flatbuffers::TypeTable *StringVecTypeTable();
inline const flatbuffers::TypeTable *AttentionParamTypeTable();
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable();
inline const flatbuffers::TypeTable *FmhcaParamTypeTable();
@ -261,6 +266,7 @@ enum OpType {
OpType_BatchNorm = 267,
OpType_ConvTranspose3D = 268,
OpType_ZeroGrad = 269,
OpType_Attention = 299,
OpType_FmhaV2 = 300,
OpType_Fmhca = 301,
OpType_SeqLen2Spatial = 302,
@ -281,7 +287,7 @@ enum OpType {
OpType_MAX = OpType_GridSample
};
inline const OpType (&EnumValuesOpType())[181] {
inline const OpType (&EnumValuesOpType())[182] {
static const OpType values[] = {
OpType_AbsVal,
OpType_QuantizedAdd,
@ -448,6 +454,7 @@ inline const OpType (&EnumValuesOpType())[181] {
OpType_BatchNorm,
OpType_ConvTranspose3D,
OpType_ZeroGrad,
OpType_Attention,
OpType_FmhaV2,
OpType_Fmhca,
OpType_SeqLen2Spatial,
@ -769,7 +776,7 @@ inline const char * const *EnumNamesOpType() {
"",
"",
"",
"",
"Attention",
"FmhaV2",
"Fmhca",
"SeqLen2Spatial",
@ -1185,11 +1192,12 @@ enum OpParameter {
OpParameter_GroupNorm = 95,
OpParameter_FmhaV2Param = 96,
OpParameter_FmhcaParam = 97,
OpParameter_AttentionParam = 98,
OpParameter_MIN = OpParameter_NONE,
OpParameter_MAX = OpParameter_FmhcaParam
OpParameter_MAX = OpParameter_AttentionParam
};
inline const OpParameter (&EnumValuesOpParameter())[98] {
inline const OpParameter (&EnumValuesOpParameter())[99] {
static const OpParameter values[] = {
OpParameter_NONE,
OpParameter_QuantizedAdd,
@ -1288,7 +1296,8 @@ inline const OpParameter (&EnumValuesOpParameter())[98] {
OpParameter_CumSum,
OpParameter_GroupNorm,
OpParameter_FmhaV2Param,
OpParameter_FmhcaParam
OpParameter_FmhcaParam,
OpParameter_AttentionParam
};
return values;
}
@ -1393,13 +1402,14 @@ inline const char * const *EnumNamesOpParameter() {
"GroupNorm",
"FmhaV2Param",
"FmhcaParam",
"AttentionParam",
nullptr
};
return names;
}
inline const char *EnumNameOpParameter(OpParameter e) {
if (e < OpParameter_NONE || e > OpParameter_FmhcaParam) return "";
if (e < OpParameter_NONE || e > OpParameter_AttentionParam) return "";
const size_t index = static_cast<int>(e);
return EnumNamesOpParameter()[index];
}
@ -1796,6 +1806,10 @@ template<> struct OpParameterTraits<FmhcaParam> {
static const OpParameter enum_value = OpParameter_FmhcaParam;
};
template<> struct OpParameterTraits<AttentionParam> {
static const OpParameter enum_value = OpParameter_AttentionParam;
};
struct OpParameterUnion {
OpParameter type;
void *value;
@ -2603,6 +2617,14 @@ struct OpParameterUnion {
return type == OpParameter_FmhcaParam ?
reinterpret_cast<const FmhcaParamT *>(value) : nullptr;
}
AttentionParamT *AsAttentionParam() {
return type == OpParameter_AttentionParam ?
reinterpret_cast<AttentionParamT *>(value) : nullptr;
}
const AttentionParamT *AsAttentionParam() const {
return type == OpParameter_AttentionParam ?
reinterpret_cast<const AttentionParamT *>(value) : nullptr;
}
};
bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
@ -2900,6 +2922,60 @@ inline flatbuffers::Offset<StringVec> CreateStringVec(
flatbuffers::Offset<StringVec> CreateStringVec(flatbuffers::FlatBufferBuilder &_fbb, const StringVecT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct AttentionParamT : public flatbuffers::NativeTable {
typedef AttentionParam TableType;
bool kv_cache;
AttentionParamT()
: kv_cache(true) {
}
};
struct AttentionParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef AttentionParamT NativeTableType;
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
return AttentionParamTypeTable();
}
bool kv_cache() const {
return GetField<uint8_t>(4, 1) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, 4) &&
verifier.EndTable();
}
AttentionParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(AttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<AttentionParam> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct AttentionParamBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_kv_cache(bool kv_cache) {
fbb_.AddElement<uint8_t>(4, static_cast<uint8_t>(kv_cache), 1);
}
explicit AttentionParamBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
AttentionParamBuilder &operator=(const AttentionParamBuilder &);
flatbuffers::Offset<AttentionParam> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<AttentionParam>(end);
return o;
}
};
inline flatbuffers::Offset<AttentionParam> CreateAttentionParam(
flatbuffers::FlatBufferBuilder &_fbb,
bool kv_cache = true) {
AttentionParamBuilder builder_(_fbb);
builder_.add_kv_cache(kv_cache);
return builder_.Finish();
}
flatbuffers::Offset<AttentionParam> CreateAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct FmhaV2ParamT : public flatbuffers::NativeTable {
typedef FmhaV2Param TableType;
int32_t heads;
@ -3784,6 +3860,9 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FmhcaParam *main_as_FmhcaParam() const {
return main_type() == OpParameter_FmhcaParam ? static_cast<const FmhcaParam *>(main()) : nullptr;
}
const AttentionParam *main_as_AttentionParam() const {
return main_type() == OpParameter_AttentionParam ? static_cast<const AttentionParam *>(main()) : nullptr;
}
const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(10);
}
@ -4209,6 +4288,10 @@ template<> inline const FmhcaParam *Op::main_as<FmhcaParam>() const {
return main_as_FmhcaParam();
}
template<> inline const AttentionParam *Op::main_as<AttentionParam>() const {
return main_as_AttentionParam();
}
struct OpBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -5006,6 +5089,32 @@ inline flatbuffers::Offset<StringVec> CreateStringVec(flatbuffers::FlatBufferBui
_data);
}
inline AttentionParamT *AttentionParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new AttentionParamT();
UnPackTo(_o, _resolver);
return _o;
}
inline void AttentionParam::UnPackTo(AttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = kv_cache(); _o->kv_cache = _e; };
}
inline flatbuffers::Offset<AttentionParam> AttentionParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateAttentionParam(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<AttentionParam> CreateAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AttentionParamT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _kv_cache = _o->kv_cache;
return MNN::CreateAttentionParam(
_fbb,
_kv_cache);
}
inline FmhaV2ParamT *FmhaV2Param::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new FmhaV2ParamT();
UnPackTo(_o, _resolver);
@ -5902,6 +6011,10 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
auto ptr = reinterpret_cast<const FmhcaParam *>(obj);
return verifier.VerifyTable(ptr);
}
case OpParameter_AttentionParam: {
auto ptr = reinterpret_cast<const AttentionParam *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -6308,6 +6421,10 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
auto ptr = reinterpret_cast<const FmhcaParam *>(obj);
return ptr->UnPack(resolver);
}
case OpParameter_AttentionParam: {
auto ptr = reinterpret_cast<const AttentionParam *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -6702,6 +6819,10 @@ inline flatbuffers::Offset<void> OpParameterUnion::Pack(flatbuffers::FlatBufferB
auto ptr = reinterpret_cast<const FmhcaParamT *>(value);
return CreateFmhcaParam(_fbb, ptr, _rehasher).Union();
}
case OpParameter_AttentionParam: {
auto ptr = reinterpret_cast<const AttentionParamT *>(value);
return CreateAttentionParam(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -7096,6 +7217,10 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
value = new FmhcaParamT(*reinterpret_cast<FmhcaParamT *>(u.value));
break;
}
case OpParameter_AttentionParam: {
value = new AttentionParamT(*reinterpret_cast<AttentionParamT *>(u.value));
break;
}
default:
break;
}
@ -7588,6 +7713,11 @@ inline void OpParameterUnion::Reset() {
delete ptr;
break;
}
case OpParameter_AttentionParam: {
auto ptr = reinterpret_cast<AttentionParamT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;
@ -7776,12 +7906,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
OpTypeTypeTable
};
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
static const char * const names[] = {
"AbsVal",
"QuantizedAdd",
@ -7948,6 +8079,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"BatchNorm",
"ConvTranspose3D",
"ZeroGrad",
"Attention",
"FmhaV2",
"Fmhca",
"SeqLen2Spatial",
@ -7966,7 +8098,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"GridSample"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_ENUM, 181, type_codes, type_refs, values, names
flatbuffers::ST_ENUM, 182, type_codes, type_refs, values, names
};
return &tt;
}
@ -8070,7 +8202,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
{ flatbuffers::ET_SEQUENCE, 0, 93 },
{ flatbuffers::ET_SEQUENCE, 0, 94 },
{ flatbuffers::ET_SEQUENCE, 0, 95 },
{ flatbuffers::ET_SEQUENCE, 0, 96 }
{ flatbuffers::ET_SEQUENCE, 0, 96 },
{ flatbuffers::ET_SEQUENCE, 0, 97 }
};
static const flatbuffers::TypeFunction type_refs[] = {
QuantizedAddTypeTable,
@ -8169,7 +8302,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
CumSumTypeTable,
GroupNormTypeTable,
FmhaV2ParamTypeTable,
FmhcaParamTypeTable
FmhcaParamTypeTable,
AttentionParamTypeTable
};
static const char * const names[] = {
"NONE",
@ -8269,10 +8403,11 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
"CumSum",
"GroupNorm",
"FmhaV2Param",
"FmhcaParam"
"FmhcaParam",
"AttentionParam"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_UNION, 98, type_codes, type_refs, nullptr, names
flatbuffers::ST_UNION, 99, type_codes, type_refs, nullptr, names
};
return &tt;
}
@ -8376,6 +8511,19 @@ inline const flatbuffers::TypeTable *StringVecTypeTable() {
return &tt;
}
inline const flatbuffers::TypeTable *AttentionParamTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_BOOL, 0, -1 }
};
static const char * const names[] = {
"kv_cache"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 1, type_codes, nullptr, nullptr, names
};
return &tt;
}
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_INT, 0, -1 }

View File

@ -187,6 +187,7 @@ enum OpType : int {
ZeroGrad,
// User define op
Attention = 299,
FmhaV2 = 300,
Fmhca = 301,
SeqLen2Spatial = 302,
@ -218,7 +219,7 @@ table Extra {
engine: string;
info: [byte];
attr:[Attribute];
// The Extra Op can be vectorized for execution
vector: bool;
}
@ -226,10 +227,14 @@ table StringVec {
data: [string];
}
table AttentionParam {
kv_cache: bool = true;
}
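
A minimal sketch of how the new AttentionParam table round-trips through the generated C++ API added in this commit (CreateAttentionParam / kv_cache()); the include path "MNN_generated.h" is an assumption:

#include "MNN_generated.h"   // header generated from this schema (path assumed)
#include <cassert>

int main() {
    flatbuffers::FlatBufferBuilder fbb;
    // kv_cache defaults to true, so only a non-default value is actually stored.
    auto param = MNN::CreateAttentionParam(fbb, /*kv_cache=*/false);
    fbb.Finish(param);

    auto attention = flatbuffers::GetRoot<MNN::AttentionParam>(fbb.GetBufferPointer());
    assert(!attention->kv_cache());
    return 0;
}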
table FmhaV2Param {
heads: int;
}
table FmhcaParam {
heads: int;
}
@ -408,7 +413,8 @@ union OpParameter {
CumSum,
GroupNorm,
FmhaV2Param,
FmhcaParam
FmhcaParam,
AttentionParam
}
table Op {

View File

@ -53,7 +53,7 @@ asm_function MNNGemmHybridInt4FP16_sdot
// int32_t useInt8;
//};
//void MNNGemmHybridInt4_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
@ -79,8 +79,6 @@ lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_
ld1 {v6.16b, v7.16b}, [x14]
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
TILE_4:
cmp x6, #4
blt TILE_1
@ -120,22 +118,14 @@ LoopSz_TILE_4:
// src : 2 x [2 x 8] : v4-5
// weight : 4 x [2 x 8] : v0-3
// dst : 2 x 4 x [4] : v16-23
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.16b, v5.16b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v9.16b, v0.16b, #4
ushr v12.16b, v1.16b, #4
and v8.16b, v0.16b, v14.16b
and v13.16b, v1.16b, v14.16b
sub v9.16b, v9.16b, v15.16b
sub v8.16b, v8.16b, v15.16b
sub v12.16b, v12.16b, v15.16b
sub v13.16b, v13.16b, v15.16b
zip1 v0.16b, v9.16b, v8.16b
zip2 v1.16b, v9.16b, v8.16b
zip1 v2.16b, v12.16b, v13.16b
zip2 v3.16b, v12.16b, v13.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v14.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v14.16b
mov v10.d[0], v4.d[1]
mov v10.d[1], v4.d[0]
mov v11.d[1], v5.d[0]
@ -244,21 +234,13 @@ LoopSz_TILE_1:
// src : 1 x [1 x 8] : v4
// weight : 4 x [2 x 8] : v0-3
// dst : 1 x 4 x [2] : v16-v19
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.8b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v16.16b, v0.16b, #4
ushr v17.16b, v1.16b, #4
and v18.16b, v0.16b, v14.16b
and v19.16b, v1.16b, v14.16b
sub v16.16b, v16.16b, v15.16b
sub v18.16b, v18.16b, v15.16b
sub v17.16b, v17.16b, v15.16b
sub v19.16b, v19.16b, v15.16b
zip1 v0.16b, v16.16b, v18.16b
zip2 v1.16b, v16.16b, v18.16b
zip1 v2.16b, v17.16b, v19.16b
zip2 v3.16b, v17.16b, v19.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v14.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v14.16b
mov v29.d[0], v4.d[1]
mov v29.d[1], v4.d[0]
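
For reference, a scalar sketch of the simplified int4 -> int8 unpack used in the rewritten loops above: each weight byte is split into its high nibble (ushr #4) and low nibble (and with the 0x0F mask register), with no "- 8" offset or zip interleaving anymore. The helper below is hypothetical, not MNN code:

#include <cstdint>
#include <cstddef>

void unpackInt4(const uint8_t* packed, int8_t* high, int8_t* low, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        high[i] = static_cast<int8_t>(packed[i] >> 4);    // ushr vX.16b, vY.16b, #4
        low[i]  = static_cast<int8_t>(packed[i] & 0x0F);  // and  vX.16b, vY.16b, mask
    }
}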

View File

@ -64,7 +64,7 @@ asm_function MNNGemmHybridInt4FP16_smmla
// int32_t useInt8;
//};
//void MNNGemmHybridInt4_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
@ -135,22 +135,14 @@ LoopDz_TILE_8:
LoopSz_TILE_8:
// src : 2 x [2 x 8] : v4-5
// weight : 4 x [2 x 8] : v0-3
// dst : 2 x 4 x [4] : v16-23
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
// dst : 2 x 4 x [4] : v16-31
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v10.16b
sub v8.16b, v8.16b, v11.16b
sub v9.16b, v9.16b, v11.16b
ushr v12.16b, v1.16b, #4
and v13.16b, v1.16b, v10.16b
sub v12.16b, v12.16b, v11.16b
sub v13.16b, v13.16b, v11.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v12.16b, v13.16b
zip2 v3.16b, v12.16b, v13.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3
@ -270,21 +262,13 @@ LoopSz_TILE_4:
// src : 2 x [2 x 8] : v4-5
// weight : 4 x [2 x 8] : v0-3
// dst : 2 x 4 x [4] : v16-23
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.16b, v5.16b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v10.16b
sub v8.16b, v8.16b, v11.16b
sub v9.16b, v9.16b, v11.16b
ushr v12.16b, v1.16b, #4
and v13.16b, v1.16b, v10.16b
sub v12.16b, v12.16b, v11.16b
sub v13.16b, v13.16b, v11.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v12.16b, v13.16b
zip2 v3.16b, v12.16b, v13.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
@ -313,7 +297,7 @@ LoopSzEnd_TILE_4:
trn2 v5.2d, v8.2d, v8.2d // alpha: oc 2,3,2,3
trn1 v6.2d, v9.2d, v9.2d // alpha: oc 4,5,4,5
trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7
MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7
MulScale_New v20, v21, v22, v23, v1, v4, v5, v6, v7
Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2)
@ -372,21 +356,13 @@ LoopSz_TILE_2:
// src : 1 x [2 x 8] : v4
// weight : 4 x [2 x 8] : v0-3
// dst : 1 x 4 x [4] : v16-19
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.16b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v10.16b
sub v8.16b, v8.16b, v11.16b
sub v9.16b, v9.16b, v11.16b
ushr v12.16b, v1.16b, #4
and v13.16b, v1.16b, v10.16b
sub v12.16b, v12.16b, v11.16b
sub v13.16b, v13.16b, v11.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v12.16b, v13.16b
zip2 v3.16b, v12.16b, v13.16b
// int4 to int8: v0, v1, v2, v3
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
@ -409,10 +385,10 @@ LoopSzEnd_TILE_2:
trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7
MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7
Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2)
uzp1 v4.4s, v0.4s, v1.4s
uzp2 v5.4s, v0.4s, v1.4s
Tile2Dequant:
ld1 {v16.8h}, [x20], #16 // zero
ld1 {v17.8h}, [x21], #16 // bias
@ -463,22 +439,14 @@ LoopSz_TILE_1:
// dst : 1 x 4 x [2] : v16-v19
prfm pldl1keep, [x25, #64] //
prfm pldl1keep, [x24, x15] //
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v4.8b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v10.16b
sub v8.16b, v8.16b, v11.16b
sub v9.16b, v9.16b, v11.16b
ushr v12.16b, v1.16b, #4
and v13.16b, v1.16b, v10.16b
sub v12.16b, v12.16b, v11.16b
sub v13.16b, v13.16b, v11.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v12.16b, v13.16b
zip2 v3.16b, v12.16b, v13.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
.inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
.inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b
.inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b
@ -505,7 +473,7 @@ LoopSzEnd_TILE_1:
fcvtn v17.4h, v20.4s
fcvtn2 v17.8h, v21.4s
Tile1Dequant:
ld1 {v1.8h}, [x20], #16 // zero
ld1 {v2.8h}, [x21], #16 // bias
ld1 {v3.h}[0], [x22] // sums
@ -535,4 +503,4 @@ ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 9)
ret
#endif

View File

@ -0,0 +1,462 @@
//
// CPUAttention.cpp
// MNN
//
// Created by MNN on 2024/03/19.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
#include <limits>
#include "CPUAttention.hpp"
#include "CPUBackend.hpp"
#include "compute/CommonOptFunction.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
#include "core/BufferAllocator.hpp"
#include "core/TensorUtils.hpp"
#include "core/OpCommonUtils.hpp"
#if defined (__aarch64__)
#define FLOAT16_T __fp16
#else
#define FLOAT16_T float
#endif
namespace MNN {
template <typename T>
static void prefill_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr,
int mMaxLength, int mNumHead, int mHeadDim, int mValueH,
int eP, int hP, int query_e, int key_h, int seq_len, int h) {
auto query_src = query->host<T>();
auto key_src = key->host<T>();
auto value_src = value->host<T>();
auto query_dst = reinterpret_cast<T*>(query_ptr);
auto key_dst = reinterpret_cast<T*>(key_ptr);
auto value_dst = reinterpret_cast<T*>(value_ptr);
// transpose query: [seq_len, num_head, head_dim] -> numhead, [seq_len/eP, head_dim, eP]
for (int i = 0; i < query_e; i++) {
for (int j = 0; j < mHeadDim; j++) {
for (int k = 0; k < eP; k++) {
int s = i * eP + k;
if (s < seq_len) {
query_dst[i * mHeadDim * eP + j * eP + k] = query_src[s * mNumHead * mHeadDim + h * mHeadDim + j];
}
}
}
}
// transpose key: [seq_len, num_head, head_dim] -> numhead, [seq_len/hP, head_dim, hP]
for (int i = 0; i < key_h; i++) {
for (int j = 0; j < mHeadDim; j++) {
for (int k = 0; k < hP; k++) {
int s = i * hP + k;
if (s < seq_len) {
key_dst[i * mHeadDim * hP + j * hP + k] = key_src[s * mNumHead * mHeadDim + h * mHeadDim + j];
}
}
}
}
// transpose value: [seq_len, num_head, head_dim] -> numhead, [head_dim/hP, seq_len, hP]
for (int i = 0; i < mValueH; i++) {
for (int j = 0; j < seq_len; j++) {
for (int k = 0; k < hP; k++) {
int hd = i * hP + k;
if (hd < mHeadDim) {
value_dst[i * mMaxLength * hP + j * hP + k] = value_src[j * mNumHead * mHeadDim + h * mHeadDim + hd];
}
}
}
}
}
template <typename T>
static void decode_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr,
int mMaxLength, int mPastLength, int mHeadDim, int mValueH, int eP, int hP, int h) {
auto query_src = query->host<T>();
auto key_src = key->host<T>();
auto value_src = value->host<T>();
auto query_dst = reinterpret_cast<T*>(query_ptr);
auto key_dst = reinterpret_cast<T*>(key_ptr);
auto value_dst = reinterpret_cast<T*>(value_ptr);
for (int i = 0; i < mHeadDim; i++) {
query_dst[i * eP] = query_src[h * mHeadDim + i];
}
// transpose key: [1, num_head, head_dim] -> numhead, [kv_seq_len/hP, head_dim, hP]
int outside_offset = UP_DIV(mPastLength, hP);
int inside_offset = mPastLength % hP;
for (int i = 0; i < mHeadDim; i++) {
key_dst[(outside_offset - (inside_offset != 0)) * mHeadDim * hP + i * hP + inside_offset] = key_src[h * mHeadDim + i];
}
// transpose value: [1, num_head, head_dim] -> numhead, [head_dim/hP, kv_seq_len, hP]
for (int i = 0; i < mValueH; i++) {
for (int j = 0; j < hP; j++) {
value_dst[i * mMaxLength * hP + mPastLength * hP + j] = value_src[h * mHeadDim + i * hP + j];
}
}
}
template <typename T>
static void prefill_unpack(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) {
auto src_ptr = reinterpret_cast<T*>(pack_qkv);
auto dst_ptr = reinterpret_cast<T*>(unpack_qkv);
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < mHeadDim; j++) {
int a = j / unit;
int b = j % unit;
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b];
}
}
}
template <typename T>
static void prefill_softmax(int* mask_ptr, float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk,
float mScale, int eP, int query_e, int seq_len, float min_val) {
T* qk_src = reinterpret_cast<T*>(unpack_qk);
T* qk_dst = reinterpret_cast<T*>(pack_qk);
for (int i = 0; i < seq_len * seq_len; i++) {
if (mask_ptr[i]) {
mask_qk[i] = qk_src[i] * mScale;
} else {
mask_qk[i] = min_val;
}
}
for (int i = 0; i < seq_len; i++) {
MNNSoftmax(softmax_qk + i * seq_len, mask_qk + i * seq_len, seq_len);
}
for (int i = 0; i < query_e; i++) {
for (int j = 0; j < seq_len; j++) {
for (int k = 0; k < eP; k++) {
int s = i * eP + k;
if (s < seq_len) {
qk_dst[i * seq_len * eP + j * eP + k] = softmax_qk[s * seq_len + j];
}
}
}
}
}
template <typename T>
static void decode_softmax(float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk,
float mScale, int eP, int kv_seq_len) {
T* qk_src = reinterpret_cast<T*>(unpack_qk);
T* qk_dst = reinterpret_cast<T*>(pack_qk);
for (int i = 0; i < kv_seq_len; i++) {
mask_qk[i] = qk_src[i] * mScale;
}
// softmax
MNNSoftmax(softmax_qk, mask_qk, kv_seq_len);
// pack qk
for (int i = 0; i < kv_seq_len; i++) {
qk_dst[i * eP] = softmax_qk[i];
}
}
void CPUAttentionImpl::allocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
mMaxLength = mPastLength + mExpandChunk;
// past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP]
mPastKey.reset(Tensor::createDevice<float>({mNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
// past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP]
mPastValue.reset(Tensor::createDevice<float>({mNumHead, mValueH, mMaxLength, hP}));
backend()->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
backend()->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
}
void CPUAttentionImpl::reallocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
mMaxLength = mPastLength + mExpandChunk;
// past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP]
auto new_key = Tensor::createDevice<float>({mNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
// past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP]
auto new_value = Tensor::createDevice<float>({mNumHead, mValueH, mMaxLength, hP});
backend()->onAcquireBuffer(new_key, Backend::STATIC);
backend()->onAcquireBuffer(new_value, Backend::STATIC);
// copy
for (int h = 0; h < mNumHead; h++) {
::memset(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes, 0, UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes);
::memset(new_value->host<char>() + h * mValueH * mMaxLength * hP * bytes, 0, mValueH * mMaxLength * hP * bytes);
::memcpy(new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes,
mPastKey->host<char>() + h * UP_DIV(mPastLength, hP) * mHeadDim * hP * bytes,
UP_DIV(mPastLength, hP) * mHeadDim * hP * bytes);
for (int i = 0; i < mValueH; i++) {
::memcpy(new_value->host<char>() + (h * mValueH + i) * mMaxLength * hP * bytes,
mPastValue->host<char>() + (h * mValueH + i) * mPastLength * hP * bytes,
mPastLength * hP * bytes);
}
}
mPastKey.reset(new_key);
mPastValue.reset(new_value);
mTempQK.reset(Tensor::createDevice<float>({mThreadNum, eP + 2, mMaxLength}));
backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC);
}
ErrorCode CPUAttentionImpl::onResize(Backend* _backend, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
mBackend = _backend;
auto core = static_cast<CPUBackend *>(backend())->functions();
int unit = core->pack;
bytes = core->bytes;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto shape = query->shape();
int seq_len = shape[1];
mThreadNum = ((CPUBackend *)backend())->threadNumber();
mIsDecode = seq_len == 1;
if (mPastLength == 0 || seq_len > 1) {
mPastLength = seq_len;
}
mNumHead = shape[2];
mHeadDim = shape[3];
mScale = 1.0 / sqrt(mHeadDim);
mValueH = UP_DIV(mHeadDim, hP);
int query_e = UP_DIV(seq_len, eP);
int key_h = UP_DIV(seq_len, hP);
// mPastLength = 10;
// alloc kv cache
allocKVCache();
int tileCount = UP_DIV(mNumHead, mThreadNum);
// temp_query
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, query_e, mHeadDim, eP}));
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
if (mIsDecode) {
mTempQK.reset(Tensor::createDevice<float>({mThreadNum, eP + 2, mMaxLength}));
backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
} else {
mTempQK.reset(Tensor::createDevice<float>({mThreadNum, 4, seq_len, seq_len}));
backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
}
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
return NO_ERROR;
}
ErrorCode CPUAttentionImpl::onExecute(Backend* _backend, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
auto core = static_cast<CPUBackend *>(backend())->functions();
int unit = core->pack;
bytes = core->bytes;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
mBackend = _backend;
auto matmulUnit = core->MNNPackedMatMul;
auto matmulRemain = core->MNNPackedMatMulRemain;
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto shape = query->shape();
int seq_len = shape[1];
mThreadNum = ((CPUBackend *)backend())->threadNumber();
mIsDecode = seq_len == 1;
if (mPastLength == 0 || seq_len > 1) {
mPastLength = seq_len;
}
mNumHead = shape[2];
mHeadDim = shape[3];
mScale = 1.0 / sqrt(mHeadDim);
mValueH = UP_DIV(mHeadDim, hP);
int query_e = UP_DIV(seq_len, eP);
int key_h = UP_DIV(seq_len, hP);
// mPastLength = 10;
int tileCount = UP_DIV(mNumHead, mThreadNum);
// try calloc kv cache
mPrefill = [=](int tId){
auto pack_q = mPackQ->host<char>() + tId * query_e * mHeadDim * eP * bytes;
auto pack_qk = mTempQK->host<char>() + tId * 4 * seq_len * seq_len * bytes;
auto unpack_qk = pack_qk + seq_len * seq_len * 2 * bytes;
auto mask_qk = reinterpret_cast<float*>(pack_qk);
auto softmax_qk = reinterpret_cast<float*>(unpack_qk);
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
int head_index = tId * tileCount;
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
// pack for matmul
auto key_dst = mPastKey->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes;
auto value_dst = mPastValue->host<char>() + h * mValueH * mMaxLength * hP * bytes;
if (bytes == 2) {
prefill_pack<int16_t>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mNumHead, mHeadDim, mValueH, eP, hP, query_e, key_h, seq_len, h);
} else {
prefill_pack<float>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mNumHead, mHeadDim, mValueH, eP, hP, query_e, key_h, seq_len, h);
}
// query @ key
int loop_e = seq_len / eP;
int remain = seq_len % eP;
for (int i = 0 ; i < loop_e; i++) {
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = mHeadDim;
parameters[2] = seq_len;
parameters[3] = seq_len * unit * bytes;
parameters[4] = 0;
parameters[5] = 0;
matmulUnit((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_dst, parameters, nullptr, nullptr, nullptr, nullptr);
}
{
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = mHeadDim;
parameters[2] = seq_len;
parameters[3] = seq_len * unit * bytes;
parameters[4] = 0;
parameters[5] = 0;
matmulRemain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr);
}
int area_offset[1] {seq_len};
core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, seq_len, area_offset);
// div scale and mask
auto mask_ptr = mask->host<int>();
if (bytes == 2) {
prefill_softmax<FLOAT16_T>(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, query_e, seq_len, -65504.0);
} else {
prefill_softmax<float>(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, query_e, seq_len, std::numeric_limits<float>::lowest());
}
// qk @ v
for (int i = 0 ; i < loop_e; i++) {
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = seq_len;
parameters[2] = mHeadDim;
parameters[3] = seq_len * unit * bytes;
parameters[4] = 0;
parameters[5] = (mMaxLength - seq_len) * hP * bytes;
matmulUnit((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(pack_qk + (i * seq_len * eP) * bytes), (float*)value_dst, parameters, nullptr, nullptr, nullptr, nullptr);
}
{
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = seq_len;
parameters[2] = mHeadDim;
parameters[3] = seq_len * unit * bytes;
parameters[4] = 0;
parameters[5] = (mMaxLength - seq_len) * hP * bytes;
matmulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(pack_qk + (loop_e * seq_len * eP) * bytes), (float*)value_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr);
}
// transpose: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
if (bytes == 2) {
prefill_unpack<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
} else {
prefill_unpack<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
}
}
};
mDecode = [=](int tId) {
int kv_seq_len = mPastLength + 1;
auto pack_q = mPackQ->host<char>() + tId * mHeadDim * eP * bytes;
auto pack_qk = mTempQK->host<char>() + tId * (eP + 2) * kv_seq_len * bytes;
auto unpack_qk = pack_qk + kv_seq_len * eP * bytes;
auto mask_qk = reinterpret_cast<float*>(pack_qk);
auto softmax_qk = reinterpret_cast<float*>(unpack_qk);
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * unit * bytes;
int head_index = tId * tileCount;
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
auto key_dst = mPastKey->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * bytes;
auto value_dst = mPastValue->host<char>() + h * mValueH * mMaxLength * hP * bytes;
// pack for matmul
if (bytes == 2) {
decode_pack<int16_t>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mPastLength, mHeadDim, mValueH, eP, hP, h);
} else {
decode_pack<float>(query, key, value, pack_q, key_dst, value_dst, mMaxLength, mPastLength, mHeadDim, mValueH, eP, hP, h);
}
// query @ key: [1, head_dim] @ [head_dim, kv_seq_len] -> [1, kv_seq_len]
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = mHeadDim;
parameters[2] = kv_seq_len;
parameters[3] = seq_len * unit * bytes;
parameters[4] = 0;
parameters[5] = 0;
matmulRemain((float*)pack_qk, (float*)pack_q, (float*)key_dst, seq_len, parameters, nullptr, nullptr, nullptr, nullptr);
int area_offset[1] {seq_len};
core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, kv_seq_len, area_offset);
if (bytes == 2) {
decode_softmax<FLOAT16_T>(mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, kv_seq_len);
} else {
decode_softmax<float>(mask_qk, softmax_qk, unpack_qk, pack_qk, mScale, eP, kv_seq_len);
}
// qk @ v: [1, kv_seq_len] @ [kv_seq_len, head_dim] -> [1, head_dim]
{
size_t shapeParameters[6];
size_t* parameters = shapeParameters;
parameters[0] = eP * bytes;
parameters[1] = kv_seq_len;
parameters[2] = mHeadDim;
parameters[3] = 1 * unit * bytes;
parameters[5] = (mMaxLength - kv_seq_len) * hP * bytes;
matmulRemain((float*)pack_qkv, (float*)pack_qk, (float*)value_dst, 1, parameters, nullptr, nullptr, nullptr, nullptr);
}
// transpose: [head_dim/unit, 1, unit] -> [1, num_head, head_dim]
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
core->MNNUnpackCUnitTranspose((float*)dst_ptr, (float*)pack_qkv, 1, mHeadDim, area_offset);
}
};
mFunction = mIsDecode ? mDecode : mPrefill;
reallocKVCache();
// compute
MNN_CONCURRENCY_BEGIN(tId, mThreadNum) {
mFunction((int)tId);
}
MNN_CONCURRENCY_END();
mPastLength += mIsDecode;
return NO_ERROR;
}
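For orientation, the prefill/decode lambdas above are a tiled, packed implementation of standard scaled-dot-product attention over the cached keys and values; the eP/hP packing, fp16 paths and matmulUnit/matmulRemain calls only reorganize that arithmetic. A minimal single-head decode-step reference in plain C++ (illustrative only, not an MNN API; buffer layout and names are assumptions):

#include <algorithm>
#include <cmath>
#include <vector>
// out = softmax(q . K^T * scale) . V for one head and one new token.
// q: [head_dim], K/V: [kv_seq_len][head_dim], out: [head_dim]
static void decodeAttentionRef(const std::vector<float>& q,
                               const std::vector<std::vector<float>>& K,
                               const std::vector<std::vector<float>>& V,
                               float scale, std::vector<float>& out) {
    const size_t kvLen = K.size(), headDim = q.size();
    std::vector<float> score(kvLen);
    for (size_t i = 0; i < kvLen; ++i) {          // query @ key -> [kv_seq_len]
        float acc = 0.f;
        for (size_t d = 0; d < headDim; ++d) {
            acc += q[d] * K[i][d];
        }
        score[i] = acc * scale;
    }
    float maxV = *std::max_element(score.begin(), score.end());
    float sum = 0.f;                               // softmax over kv_seq_len
    for (auto& s : score) { s = std::exp(s - maxV); sum += s; }
    for (auto& s : score) { s /= sum; }
    out.assign(headDim, 0.f);                      // qk @ v -> [head_dim]
    for (size_t i = 0; i < kvLen; ++i) {
        for (size_t d = 0; d < headDim; ++d) {
            out[d] += score[i] * V[i][d];
        }
    }
}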
CPUAttention::CPUAttention(Backend* backend, bool kv_cache) : Execution(backend) {
mImpl.reset(new CPUAttentionImpl(backend, kv_cache));
}
CPUAttention::CPUAttention(std::shared_ptr<CPUAttentionImpl> impl, Backend *backend) : Execution(backend), mImpl(impl) {}
ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onResize(backend(), inputs, outputs);
}
ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onExecute(backend(), inputs, outputs);
}
bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new CPUAttention(mImpl, bn);
return true;
}
class CPUAttentionCreator : public CPUBackend::Creator {
public:
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op, Backend* backend) const override {
auto param = op->main_as_AttentionParam();
return new CPUAttention(backend, param->kv_cache());
}
};
REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention);
} // namespace MNN
#endif

View File

@ -0,0 +1,58 @@
//
// CPUAttention.hpp
// MNN
//
// Created by MNN on 2024/03/19.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
#ifndef CPUATTENTION_HPP
#define CPUATTENTION_HPP
#include <functional>
#include "core/Execution.hpp"
namespace MNN {
class CPUAttentionImpl {
public:
CPUAttentionImpl(Backend *backend, bool kv_cache) : mBackend(backend), mKVCache(kv_cache) {}
~CPUAttentionImpl() = default;
ErrorCode onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
private:
void allocKVCache();
void reallocKVCache();
Backend* backend() { return mBackend; }
private:
Backend* mBackend;
bool mKVCache;
float mScale;
const int mExpandChunk = 64;
int mThreadNum = 1;
bool mIsDecode = false;
int mPastLength = 0, mMaxLength = 0;
std::shared_ptr<Tensor> mPastKey, mPastValue, mTempQK;
std::shared_ptr<Tensor> mPackQ, mPackQKV;
int mNumHead = 0, mHeadDim = 0, mValueH = 0;
int eP, lP, hP, bytes;
std::function<void(int)> mFunction, mPrefill, mDecode;
};
class CPUAttention : public Execution {
public:
CPUAttention(Backend *backend, bool kv_cache);
CPUAttention(std::shared_ptr<CPUAttentionImpl> impl, Backend *backend);
virtual ~CPUAttention() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
std::shared_ptr<CPUAttentionImpl> mImpl;
};
} // namespace MNN
#endif // CPUATTENTION_HPP
#endif

View File

@ -93,7 +93,7 @@ public:
virtual void* onMapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* srcTensor) override;
virtual bool onUnmapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* dstTensor, void* mapPtr) override;
virtual void onResizeBegin() override;
virtual ErrorCode onResizeEnd() override;
@ -194,6 +194,12 @@ private:
CPUBackend::addCreator(opType, &_temp); \
}
#define REGISTER_CPU_OP_CREATOR_TRANSFORMER(name, opType) \
void ___##name##__##opType##__() { \
static name _temp;\
CPUBackend::addCreator(opType, &_temp); \
}
} // namespace MNN
#endif /* CPUBackend_hpp */

View File

@ -105,6 +105,10 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
mInnerSize *= inputs.at(0)->length(i);
}
mInnerSize /= mResource->mGroup;
if (mResource->mIniGammaBeta) {
MNN_ASSERT(mResource->mGamma->size() == mInnerSize * sizeof(float));
}
return NO_ERROR;
}
for (int i = 0; i < rank - mResource->mAxis; ++i) {
@ -113,6 +117,9 @@ ErrorCode CPULayerNorm::onResize(const std::vector<Tensor*> &inputs,
for (int i = rank - mResource->mAxis; i < rank; ++i) {
mInnerSize *= inputs.at(0)->length(i);
}
if (mResource->mIniGammaBeta) {
MNN_ASSERT(mResource->mGamma->size() == mInnerSize * sizeof(float));
}
if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) {
mInpZero.resize(1);
mOutZero.resize(1);

View File

@ -75,6 +75,9 @@ extern void ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
extern void ___CPURasterDiffCreator__OpType_RasterDiff__();
extern void ___CPUTextureCreator__OpType_Texture__();
#endif
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___CPUAttentionCreator__OpType_Attention__();
#endif
void registerCPUOps() {
___CPUCropAndResizeCreator__OpType_CropAndResize__();
___CPUArgMaxCreator__OpType_ArgMax__();
@ -150,5 +153,8 @@ ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
___CPURasterDiffCreator__OpType_RasterDiff__();
___CPUTextureCreator__OpType_Texture__();
#endif
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___CPUAttentionCreator__OpType_Attention__();
#endif
}
}

View File

@ -3,7 +3,7 @@
#ifdef MNN_USE_NEON
#include <arm_neon.h>
#include "./FunctionSummary.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
extern "C" {
void MNNTranspose32Bit4x4(int32_t* dstO, const int32_t* srcO, int32_t* dim);

View File

@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32
// int32_t useInt8;
//};
//void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
//void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
@ -120,8 +120,6 @@ LoopSz_TILE_4:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
Unit_TILE_4:
@ -244,8 +242,6 @@ LoopSz_TILE_1:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
Unit_TILE_1:
@ -261,7 +257,7 @@ LoopSz_TILE_1:
smlal v11.4s, v5.4h, v29.4h
smlal v12.4s, v5.4h, v30.4h
smlal v13.4s, v5.4h, v31.4h
//.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
subs x26, x26, #1

View File

@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32_sdot
// int32_t useInt8;
//};
//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
@ -99,11 +99,9 @@ LoopDz_TILE_12:
movi v25.4s, #0
movi v26.4s, #0
movi v27.4s, #0
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_12:
// src : 4(batch) x [1 x 4] : v4
// weight : 4(oc) x [1 x 4] : v0
@ -113,8 +111,6 @@ LoopSz_TILE_12:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
@ -203,11 +199,9 @@ LoopDz_TILE_8:
movi v21.4s, #0
movi v22.4s, #0
movi v23.4s, #0
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_8:
// src : 4(batch) x [1 x 4] : v4
// weight : 4(oc) x [1 x 4] : v0
@ -217,8 +211,6 @@ LoopSz_TILE_8:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
@ -294,8 +286,6 @@ LoopDz_TILE_4:
dup v19.4s, wzr
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_4:
// src : 4(batch) x [1 x 4] : v4
// weight : 4(oc) x [1 x 4] : v0
@ -305,8 +295,6 @@ LoopSz_TILE_4:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0
.inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1
@ -368,8 +356,6 @@ LoopDz_TILE_1:
dup v16.4s, wzr
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_1:
// src : 1(batch) x [1 x 4] : v4
// weight : 4(oc) x [1 x 4] : v0
@ -379,10 +365,8 @@ LoopSz_TILE_1:
// int4->int8
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0]
subs x26, x26, #1

View File

@ -43,7 +43,7 @@ asm_function MNNGemmHybridInt4FP32_smmla
// int32_t useInt8;
//};
//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param);
// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param
@ -106,28 +106,18 @@ LoopDz_TILE_8:
// mask
movi v10.16b, #15
// offset
movi v11.16b, #8
LoopSz_TILE_8:
// src : 2 x [2 x 8] : v4-5
// weight : 4 x [2 x 8] : v0-3
// dst : 2 x 4 x [4] : v16-23
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
ld1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x24], x15 // src
// int4 to int8: v0, v1, v2, v3
ushr v4.16b, v0.16b, #4
and v5.16b, v0.16b, v10.16b
sub v4.16b, v4.16b, v11.16b
sub v5.16b, v5.16b, v11.16b
ushr v8.16b, v1.16b, #4
and v9.16b, v1.16b, v10.16b
sub v8.16b, v8.16b, v11.16b
sub v9.16b, v9.16b, v11.16b
zip1 v0.16b, v4.16b, v5.16b
zip2 v1.16b, v4.16b, v5.16b
zip1 v2.16b, v8.16b, v9.16b
zip2 v3.16b, v8.16b, v9.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
.inst 0x4e80a590 // smmla v16.4s, v12.16b, v0.16b
.inst 0x4e81a591 // smmla v17.4s, v12.16b, v1.16b
@ -247,27 +237,17 @@ LoopDz_TILE_4:
dup v23.4s, wzr
// mask
movi v10.16b, #15
// offset
movi v11.16b, #8
LoopSz_TILE_4:
// src : 2 x [2 x 8] : v4-5
// weight : 4 x [2 x 8] : v0-3
// dst : 2 x 4 x [4] : v16-23
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
// int4 to int8: v0, v1, v2, v3
ushr v4.16b, v0.16b, #4
and v5.16b, v0.16b, v10.16b
sub v4.16b, v4.16b, v11.16b
sub v5.16b, v5.16b, v11.16b
ushr v6.16b, v1.16b, #4
and v7.16b, v1.16b, v10.16b
sub v6.16b, v6.16b, v11.16b
sub v7.16b, v7.16b, v11.16b
zip1 v0.16b, v4.16b, v5.16b
zip2 v1.16b, v4.16b, v5.16b
zip1 v2.16b, v6.16b, v7.16b
zip2 v3.16b, v6.16b, v7.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v10.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v10.16b
ld1 {v4.16b, v5.16b}, [x24], x15 // src
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
@ -349,27 +329,17 @@ LoopDz_TILE_2:
dup v19.4s, wzr
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_2:
// src : 1 x [2 x 8] : v4
// weight : 4 x [2 x 8] : v0-3
// dst : 1 x 4 x [4] : v16-19
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
ushr v10.16b, v1.16b, #4
and v11.16b, v1.16b, v14.16b
sub v10.16b, v10.16b, v15.16b
sub v11.16b, v11.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v10.16b, v11.16b
zip2 v3.16b, v10.16b, v11.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v14.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v14.16b
ld1 {v4.16b}, [x24], x15 // src
.inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b
.inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b
@ -438,28 +408,18 @@ LoopDz_TILE_1:
dup v19.4s, wzr
// mask
movi v14.16b, #15
// offset
movi v15.16b, #8
LoopSz_TILE_1:
// src : 1 x [1 x 8] : v4
// weight : 4 x [2 x 8] : v0-3
// dst : 1 x 4 x [2] : v16-v19
//ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight
ld1 {v0.16b, v1.16b}, [x25], #32 // weight
ld1 {v8.16b, v9.16b}, [x25], #32 // weight
// int4 to int8: v0, v1, v2, v3
ushr v8.16b, v0.16b, #4
and v9.16b, v0.16b, v14.16b
sub v8.16b, v8.16b, v15.16b
sub v9.16b, v9.16b, v15.16b
ushr v10.16b, v1.16b, #4
and v11.16b, v1.16b, v14.16b
sub v10.16b, v10.16b, v15.16b
sub v11.16b, v11.16b, v15.16b
zip1 v0.16b, v8.16b, v9.16b
zip2 v1.16b, v8.16b, v9.16b
zip1 v2.16b, v10.16b, v11.16b
zip2 v3.16b, v10.16b, v11.16b
ushr v0.16b, v8.16b, #4
and v1.16b, v8.16b, v14.16b
ushr v2.16b, v9.16b, #4
and v3.16b, v9.16b, v14.16b
ld1 {v4.8b}, [x24], x15 // src
.inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b
.inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b

View File

@ -18,7 +18,7 @@
#include "math/Vec.hpp"
#include <vector>
#include "../CPURuntime.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
// TODO: Find better way to optimize it
#include "../CPUBinary.hpp"
#include "../CPUUnary.hpp"
@ -376,7 +376,7 @@ void MNN1BitCopyFast (uint8_t* dstO, const uint8_t* srcO, int size, int stride,
#ifdef MNN_USE_SSE
std::vector<uint8_t> arr(16, val);
auto val16 = _mm_loadu_ps((float*)arr.data());
for (; cnt >= 16; cnt-=16) {
_mm_storeu_ps((float*)dstO, val16);
dstO += 16;
@ -427,7 +427,7 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
if (size >= 8) {
auto sum4_1 = _mm_set_ps1(0.f);
auto sum4_2 = _mm_set_ps1(0.f);
for (; i < size8; i += 8) {
auto v4 = _mm_loadu_ps(src);
auto u4 = _mm_loadu_ps(src + 4);
@ -435,7 +435,7 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
sum4_2 = _mm_add_ps(sum4_2, u4);
src += 8;
}
sum4_1 = _mm_add_ps(sum4_1, sum4_2);
_mm_storeu_ps(tmp, sum4_1);
sum += (tmp[0] + tmp[1] + tmp[2] + tmp[3]);
@ -823,7 +823,7 @@ void MNNGemmHybridInt8FP32(float* C, const int8_t* A, const int8_t* B, size_t sr
}
}
}
// int32->float
for (int cn = 0; cn < pack; ++cn) {
float val = (float)tmp[cn] * scale[0];
@ -873,7 +873,7 @@ void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t sr
}
}
}
// int32->float
for (int cn = 0; cn < pack; ++cn) {
float val = (float)tmp[cn] * scale[0];
@ -3333,7 +3333,7 @@ void MNNCoreFunctionInit() {
gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg<float, Vec4, 4>);
// Set min value as -(1 << 24)
gCoreFunction->MNNPoolingMax = (decltype(gCoreFunction->MNNPoolingMax))(poolingMax<float, Vec4, 4, -16777216>);
gCoreFunction->MNNPoolingMaxWithRedice = (decltype(gCoreFunction->MNNPoolingMaxWithRedice))(poolingMaxWithRedice<float, -16777216>);
// ImageProcess Functions
gCoreFunction->MNNRGBAToBGRA = MNNRGBAToBGRA;
@ -3353,7 +3353,7 @@ void MNNCoreFunctionInit() {
gCoreFunction->MNN4BitcopyFast = MNN4BitcopyFast;
gCoreFunction->MNN2BitcopyFast = MNN2BitcopyFast;
gCoreFunction->MNN1BitcopyFast = MNN1BitCopyFast;
gCoreFunction->MNNAccumulateSequenceNumber = MNNAccumulateSequenceNumber;
cpuinfo_arm_isa gCPUInfo;

View File

@ -33,19 +33,7 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
resource->lU = lU;
resource->hP = hP;
resource->lP = lP;
// Reorder weight
auto dstWInt8 = resource->mWeight->host<int8_t>();
auto srcWInt8 = int8Info->weight.get();
// oc, ic -> oc/hP, ic/lP, hP, lP
for (int i = 0; i < hU; i++) {
for (int j = 0; j < lU; j++) {
for (int k = 0; k < hP; k++) {
for (int l = 0; l < lP; l++) {
dstWInt8[i * srcChannel * hP + j * hP * lP + k * lP + l] = srcWInt8[(i * hP + k) * srcChannel + (j * lP + l)];
}
}
}
}
// Save scale bias
resource->mDequantize.mScaleBias.reset(MNN::Tensor::createDevice<float>({hU * hP * 2}));
res = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC);
@ -56,6 +44,12 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
auto biasPtr = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(alphaPtr) + hU * hP * bytes);
::memset(alphaPtr, 0, 2 * hU * hP * bytes);
int h = int8Info->alpha.size();
if (int8Info->canUseInt4) {
// int4 -> uint4: fold the -8 offset into the bias
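// (why the fold works: for a signed nibble q = u - 8 with u in [0, 15],
//  scale * q + bias == scale * u + (bias - 8 * scale), so subtracting
//  8 * scale from the additive term here lets the kernels below consume
//  the unsigned nibble u directly, without re-centering by -8)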
for (int i = 0; i < h/2; ++i) {
int8Info->alpha.get()[2 * i] -= 8 * int8Info->alpha.get()[2 * i + 1];
}
}
if (bytes == 2) {
auto core = static_cast<CPUBackend*>(resource->backend)->functions();
if (int8Info->asymmetric) {
@ -86,21 +80,53 @@ bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr<ConvolutionCommon::
MNN_ASSERT(weightLength % 2 == 0);
weightLength = UP_DIV(weightLength, 2);
resource->mDequantize.bits = 4;
std::shared_ptr<MNN::Tensor> weightLow(Tensor::createDevice<uint8_t>(
{weightLength}));
auto res = resource->backend->onAcquireBuffer(weightLow.get(), Backend::STATIC);
if (!res) {
return false;
auto srcPtr = int8Info->weight.get();
auto dstPtr = resource->mWeight->host<uint8_t>();
// oc, ic -> oc/hP, ic/lP, hP, lP
if (hP == 8 && lP == 8) {
for (int i = 0; i < hU; i++) {
for (int j = 0; j < lU; j++) {
for (int k = 0; k < 2; k++) {
for (int n = 0; n < 16; n++) {
int hp_idx = n / 8;
int lp_idx = n % 8;
int s0 = srcPtr[(i * hP + k * 4 + hp_idx) * srcChannel + (j * lP + lp_idx)];
int s1 = srcPtr[(i * hP + k * 4 + hp_idx + 2) * srcChannel + (j * lP + lp_idx)];
int d = (s0 + 8) * 16 + (s1 + 8);
dstPtr[(i * srcChannel * hP + j * hP * lP + k * 32) / 2 + n] = (uint8_t)d;
}
}
}
}
} else {
for (int i = 0; i < hU; i++) {
for (int j = 0; j < lU; j++) {
for (int k = 0; k < hP; k++) {
for (int l = 0; l < lP; l+=2) {
int s0 = srcPtr[(i * hP + k) * srcChannel + (j * lP + l)];
int s1 = srcPtr[(i * hP + k) * srcChannel + (j * lP + l + 1)];
int d = (s0 + 8) * 16 + (s1 + 8);
dstPtr[(i * srcChannel * hP + j * hP * lP + k * lP + l) / 2] = d;
}
}
}
}
}
auto srcPtr = resource->mWeight->host<int8_t>();
auto dstPtr = weightLow->host<uint8_t>();
for (int i=0; i < weightLength; ++i) {
int s0 = srcPtr[2 * i + 0];
int s1 = srcPtr[2 * i + 1];
int d = (s0 + 8) * 16 + (s1 + 8);
dstPtr[i] = d;
} else {
// Reorder weight for int8
auto dstWInt8 = resource->mWeight->host<int8_t>();
auto srcWInt8 = int8Info->weight.get();
// oc, ic -> oc/hP, ic/lP, hP, lP
for (int i = 0; i < hU; i++) {
for (int j = 0; j < lU; j++) {
for (int k = 0; k < hP; k++) {
for (int l = 0; l < lP; l++) {
dstWInt8[i * srcChannel * hP + j * hP * lP + k * lP + l] = srcWInt8[(i * hP + k) * srcChannel + (j * lP + l)];
}
}
}
}
resource->mWeight = weightLow;
}
return true;
}
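To make the nibble layout concrete: the int4 path above stores two signed weights per byte as unsigned nibbles, (s0 + 8) * 16 + (s1 + 8), and relies on the -8 offset having been folded into the bias so later kernels can shift/mask without re-centering. A self-contained round-trip sketch (function names are illustrative, not MNN APIs):

#include <cassert>
#include <cstdint>
// high nibble = s0 + 8, low nibble = s1 + 8, both in [0, 15]
static uint8_t packInt4Pair(int s0, int s1) {
    return static_cast<uint8_t>((s0 + 8) * 16 + (s1 + 8));
}
static void unpackInt4Pair(uint8_t d, int& s0, int& s1) {
    s0 = (d >> 4) - 8;   // undo the +8 offset (or keep u and rely on the bias fold)
    s1 = (d & 0x0F) - 8;
}
int main() {
    for (int a = -8; a <= 7; ++a) {
        for (int b = -8; b <= 7; ++b) {
            int x = 0, y = 0;
            unpackInt4Pair(packInt4Pair(a, b), x, y);
            assert(x == a && y == b); // every signed pair round-trips
        }
    }
    return 0;
}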
@ -129,10 +155,10 @@ ConvolutionHybrid::ConvolutionHybrid(const Convolution2DCommon *common, Backend
hPack = unit;
lPack = unit;
// [oc, ic] => [oc/unit, ic/src_unit, unit, src_unit]
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
hPack = 8;
lPack = 8;
}
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
hPack = 8;
lPack = 8;
}
auto hU = UP_DIV(outputCount, hPack);
auto lU = UP_DIV(inputCount, lPack);
ConvolutionHybrid::initQuantizeResource(quantInfo, mResource, hU, hPack, lU, lPack, outputCount, (int)originWeightSize / (int)biasSize, common->kernelX() * common->kernelY(), core->bytes);
@ -237,8 +263,8 @@ ErrorCode ConvolutionHybrid::onResize(const std::vector<Tensor *> &inputs, const
int iTileCount = UP_DIV(iC4, threadNumber);
if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla.
ANeedToPack8 = true;
}
int8_t order[32] = {0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31, 8, 9, 10, 11, 4, 5, 6, 7, 24, 25, 26, 27, 20, 21, 22, 23};
}
int8_t order[32] = {0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31, 8, 9, 10, 11, 4, 5, 6, 7, 24, 25, 26, 27, 20, 21, 22, 23};
allocDynamicQuantInfo(threadNumber, batch, ic, oc, bytes);
mDynamicQuant = [=]() {
auto maxPtr = mQuantInfo.quant_info.host<uint8_t>();

View File

@ -15,7 +15,7 @@
#include "core/TensorUtils.hpp"
#include "math/WingoradGenerater.hpp"
#include <MNN/AutoTime.hpp>
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

View File

@ -15,7 +15,7 @@
#include "core/TensorUtils.hpp"
#include "math/WingoradGenerater.hpp"
#include <MNN/AutoTime.hpp>
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
constexpr int FULSE_THRESHHOLD_NUMERATOR = 10;
constexpr int FULSE_THRESHHOLD_DENOMINATOR = 10;

View File

@ -16,7 +16,7 @@
#include "core/TensorUtils.hpp"
#include "math/Vec.hpp"
#include "core/BufferAllocator.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
using Vec4 = MNN::Math::Vec<float, 4>;
namespace MNN {

View File

@ -15,7 +15,7 @@
#include "core/TensorUtils.hpp"
#include "math/WingoradGenerater.hpp"
#include <MNN/AutoTime.hpp>
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
//#define MNN_WINOGRAD_PRINT_REDUCE_RATE

View File

@ -16,7 +16,7 @@
#include "core/TensorUtils.hpp"
#include "math/Vec.hpp"
#include "core/BufferAllocator.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
#define PARAMETERSIZE 6
using Vec4 = MNN::Math::Vec<float, 4>;

View File

@ -10,7 +10,7 @@
#include <cstring> // for memset
#include "Int8FunctionsOpt.h"
#include "core/Macro.h"
#include "common/CommonCompute.hpp"
#include "core/CommonCompute.hpp"
#include "CommonOptFunction.h"
#include "math/Vec.hpp"
@ -1537,7 +1537,7 @@ static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src,
}
for (int i = 0; i < pack; ++i) {
float val = (dstInt32[i] + bias_z[i]) * scale_z[i];
int valOut = roundf(val) + offset;
if (valOut > parameters->maxValue + offset) {
@ -1615,7 +1615,7 @@ void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWi
for (int y = 0; y < kernely; ++y) {
for (int x = 0; x < kernelx; ++x) {
const int8_t* inputPtr = srcPtr + pack* (x + inputWidth* y);
for (int idx = 0; idx < pack; ++idx) {
for (int idx = 0; idx < pack; ++idx) {
results[idx] = std::max(results[idx], *(inputPtr + idx));
}
}
@ -2125,7 +2125,7 @@ void MNNCoreInt8FunctionInit() {
// pooling
gCoreFunc->MNNAvgPoolInt8 = MNNAvgPoolInt8;
gCoreFunc->MNNMaxPoolInt8 = MNNMaxPoolInt8;
// Norm
gCoreFunc->MNNNormInt8 = MNNNormInt8;
@ -2143,7 +2143,7 @@ void MNNCoreInt8FunctionInit() {
gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 4>;
// ConvDepthwise
gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3;
}
if (core->supportI8mm) {
// MatMul

View File

@ -15,7 +15,7 @@
#include "CommonOptFunction.h"
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
#include "MNN/AutoTime.hpp"
#include <math.h>
#ifdef MNN_USE_SSE

View File

@ -16,8 +16,8 @@
#include "core/TensorUtils.hpp"
#include "math/Vec.hpp"
#include "core/BufferAllocator.hpp"
#include "common/MemoryFormater.h"
#include "common/CommonCompute.hpp"
#include "core/MemoryFormater.h"
#include "core/CommonCompute.hpp"
using Vec4 = MNN::Math::Vec<float, 4>;
namespace MNN {

View File

@ -12,7 +12,7 @@
#include <map>
#include "core/Macro.h"
#include "math/Vec.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
using Vec4 = MNN::Math::Vec<float, 4>;
#define DEFAULT_UNIT 8

View File

@ -75,7 +75,6 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
const float* scale_ptr = param[4];
auto one_int16 = _mm256_set1_epi16(1);
auto offset_int8 = _mm256_set1_epi8(128);
auto _int4_signed_8 = _mm256_set1_ps(8);
for (int ci = 0; ci < dst_depth_quad; ++ci) {
float* dstZ = C + ci * pack * realSize;
const int8_t* weight = B + ci * weight_step;
@ -83,14 +82,13 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
auto zero = zero_ptr + ci * pack;
auto bias = bias_ptr + ci * pack;
__m256 alphaValue = _mm256_loadu_ps(alpha);
auto extra_sum = _mm256_mul_ps(_int4_signed_8, alphaValue);
for (int j = 0; j < realSize; ++j) {
const float* sums = sums_ptr + j;
const float* scale = scale_ptr + j;
float* dstX = dstZ + j * pack;
__m256 scaleValue = _mm256_set1_ps(scale[0]);
auto sum_val = _mm256_set1_ps(sums[0]);
__m256 biasValue = _mm256_add_ps(_mm256_loadu_ps(bias), _mm256_mul_ps(_mm256_sub_ps(_mm256_loadu_ps(zero), extra_sum), sum_val));
__m256 biasValue = _mm256_add_ps(_mm256_loadu_ps(bias), _mm256_mul_ps(_mm256_loadu_ps(zero), sum_val));
const int8_t* srcBatch = A + j * pack;
auto oc0123_int16 = _mm256_set1_epi16(0);
auto oc4567_int16 = _mm256_set1_epi16(0);
@ -103,10 +101,8 @@ void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t s
const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride;
auto s0 = _mm256_castpd_si256(_mm256_broadcast_sd((double*)srcZ));
auto wi4 = _mm256_castps_si256(_mm256_loadu_ps((const float*)weightZ));
auto w_high = _mm256_and_si256(mask, _mm256_srli_epi16(wi4, 4));
auto w_low = _mm256_and_si256(mask, wi4);
auto w0_ = _mm256_unpacklo_epi8(w_high, w_low);
auto w1_ = _mm256_unpackhi_epi8(w_high, w_low);
auto w0_ = _mm256_and_si256(mask, _mm256_srli_epi16(wi4, 4));
auto w1_ = _mm256_and_si256(mask, wi4);
auto w0 = _mm256_permute2x128_si256(w0_, w1_, 0x20);
auto w1 = _mm256_permute2x128_si256(w0_, w1_, 0x31);
oc0123_int16 = _mm256_maddubs_epi16(w0, s0); // int16_t sum
@ -194,9 +190,11 @@ void _AVX_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t s
auto oc_04261537 = _mm256_add_epi32(oc_04261537_lo, oc_04261537_hi);
auto oc_0426 = _mm256_extractf128_si256(oc_04261537, 0);
auto oc_1537 = _mm256_extractf128_si256(oc_04261537, 1);
auto oc_0123 = _mm_unpacklo_epi32(oc_0426, oc_1537);
auto oc_4567 = _mm_unpackhi_epi32(oc_0426, oc_1537);
auto sum8 = _mm256_set_m128i(oc_0123, oc_4567);
auto oc_0145 = _mm_unpacklo_epi32(oc_0426, oc_1537);
auto oc_2367 = _mm_unpackhi_epi32(oc_0426, oc_1537);
auto oc_0123 = _mm_unpacklo_epi64(oc_0145, oc_2367);
auto oc_4567 = _mm_unpackhi_epi64(oc_0145, oc_2367);
auto sum8 = _mm256_set_m128i(oc_4567, oc_0123);
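// lane bookkeeping: oc_0426 = {o0,o4,o2,o6}, oc_1537 = {o1,o5,o3,o7};
// the epi32 unpacks yield {o0,o1,o4,o5} / {o2,o3,o6,o7}, and the epi64 unpacks
// regroup them into {o0..o3} / {o4..o7}, so sum8 holds o0..o7 in ascending lane order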
__m256 f0 = _mm256_cvtepi32_ps(sum8);
__m256 fs = _mm256_mul_ps(_mm256_mul_ps(f0, scaleValue), alphaValue);
fs = _mm256_add_ps(biasValue, fs);

View File

@ -9,7 +9,7 @@
#include <map>
#include "Vec8.hpp"
#include "FunctionSummary.hpp"
#include "common/MemoryFormater.h"
#include "core/MemoryFormater.h"
#define PACK_UNIT 8
namespace MNN {

View File

@ -656,8 +656,8 @@ static inline __m128i _load_int4_to_int8(const uint8_t* src) {
int32_t data[4];
int8_t temp[16];
for (int i = 0; i < 8; ++i) {
temp[2 * i] = (src[i] >> 4) - 8;
temp[2 * i +1] = (src[i] & c) - 8;
temp[2 * i] = (src[i] >> 4);
temp[2 * i +1] = (src[i] & c);
}
auto int8_tx16 = _mm_loadu_si128((const __m128i*)temp);
return int8_tx16;

View File

@ -1,32 +1,27 @@
#include "AllShader.hpp"
const char* shader_MetalReLU6_metal =
"kernel void relu6_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant float4 &minMax [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=clamp(in[int(gid)],(M)(minMax.x),(M)(minMax.y));\n"
"}\n"
"kernel void relu6_x4(const device M4 *in [[buffer(0)]],\n"
"struct Param {\n"
" float minV;\n"
" float maxV;\n"
" int size;\n"
" int remain;\n"
"};\n"
"kernel void relu6(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float4 &minMax [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=clamp(in[int(gid)],(M4)minMax.x,(M4)minMax.y);\n"
" constant Param &p [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<p.size) {\n"
" out[int(gid.x)]=clamp(in[int(gid.x)],(M4)p.minV,(M4)p.maxV);\n"
" }\n"
"}\n"
;
const char* shader_MetalReLU_metal =
"kernel void relu_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto V=in[int(gid)];\n"
" out[int(gid)]=fmax(V,(M)0)+fmin(V,(M)0)*M(slope);\n"
"}\n"
"kernel void relu_x4(const device M4 *in [[buffer(0)]],\n"
"kernel void relu(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto V=in[int(gid)];\n"
" out[int(gid)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(slope);\n"
" constant Param &p [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<p.size) {\n"
" auto V=in[int(gid.x)];\n"
" out[int(gid.x)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(p.minV);\n"
" }\n"
"}\n"
;
const char* shader_MetalConvolutionDepthwise_metal =
@ -59,7 +54,7 @@ const char* shader_MetalConvolutionDepthwise_metal =
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;\n"
" \n"
" int oz=gid.z % cst.slice;\n"
" int oz=gid.z/cst.batch;\n"
" int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
@ -198,16 +193,16 @@ const char* shader_MetalConvolution_metal =
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+(int)idx_c*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" auto z_out=out+idx_b*cst.output_size+(int)idx_c*cst.batch*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" auto in4=z_in[z*cst.input_size+y*dilation_h+x*cst.dilation_x];\n"
" auto in4=z_in[z*cst.input_size*cst.batch+y*dilation_h+x*cst.dilation_x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
@ -231,14 +226,14 @@ const char* shader_MetalConvolution_metal =
" bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n"
" int offset_x=(int)gid.x*2-cst.pad_x;\n"
" int offset_y=(int)gid.y-cst.pad_y;\n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" FLOAT4 result4=0,result5=0,result6=0,result7=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += cst.input_size) {\n"
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += (cst.input_size*cst.batch)) {\n"
" auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n"
" auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n"
" auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n"
@ -312,18 +307,18 @@ const char* shader_MetalConvolution_metal =
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" bool valid=(idx_w+1<cst.output_width);\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<cst.kernel_y; y++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x];\n"
" auto in4_0=z_in[z*cst.input_size+y*cst.input_width];\n"
" auto in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
" for (auto x=1; x<cst.kernel_x; x++) {\n"
" in4_0=z_in[z*cst.input_size+y*cst.input_width+x];\n"
" in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+x];\n"
" result1 += FLOAT4(in4_0*wt4);\n"
" wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
@ -352,9 +347,9 @@ const char* shader_MetalConvolution_metal =
" int3 uz=idx_w+int3(1,2,3);\n"
" bool3 valids=uz.xyz<cst.output_width;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
@ -365,7 +360,7 @@ const char* shader_MetalConvolution_metal =
" auto wt4_0=wt_base[0];\n"
" auto wt4_1=wt_base[1];\n"
" auto wt4_2=wt_base[2];\n"
" auto z_in_base=z_in+z*cst.input_size+y*cst.input_width;\n"
" auto z_in_base=z_in+z*cst.batch*cst.input_size+y*cst.input_width;\n"
" auto in4_0=z_in_base[0];\n"
" result0 += FLOAT4(in4_0*wt4_0);\n"
" \n"
@ -420,14 +415,14 @@ const char* shader_MetalConvolution_metal =
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size) {\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
@ -471,14 +466,14 @@ const char* shader_MetalConvolution_metal =
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_slice*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_slice*cst.output_size+uz[0]*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size) {\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
@ -489,7 +484,7 @@ const char* shader_MetalConvolution_metal =
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids) { z_out += cst.output_size; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
" if (valids) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
"}\n"
;
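The index changes in the convolution kernels above (base offset idx_b*cst.input_size with a per-slice stride of cst.input_size*cst.batch) suggest the NC4HW4 buffers are now addressed slice-major, roughly [slice, batch, plane, 4], where the previous code used [batch, slice, plane, 4]. A small host-side sketch of the two offset formulas, under that assumption (illustrative, not an MNN API):

#include <cstdio>
// Flat M4-element offset of (batch b, channel-slice z, plane position p).
// assumed old layout [batch][slice][plane]: b * slice * size + z * size + p
// assumed new layout [slice][batch][plane]: z * batch * size + b * size + p
static int offsetOld(int b, int z, int p, int slice, int size) { return b * slice * size + z * size + p; }
static int offsetNew(int b, int z, int p, int batch, int size) { return z * batch * size + b * size + p; }
int main() {
    const int batch = 2, slice = 3, size = 8; // size = width * height of one plane
    // the rewritten kernels step through channel slices with stride size * batch:
    int stride = offsetNew(1, 1, 5, batch, size) - offsetNew(1, 0, 5, batch, size);
    std::printf("slice stride = %d (expected %d)\n", stride, size * batch);
    (void)offsetOld(1, 0, 5, slice, size);
    return 0;
}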
const char* shader_MetalReduction_metal =
@ -567,34 +562,9 @@ const char* shader_MetalBackend_metal =
"struct tensor_shape {\n"
" int size;\n"
" int channel;\n"
" int slice;\n"
" int batch;\n"
" int batch_slices;\n"
"};\n"
"kernel void version_func_002(const device uchar *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" // do nothing,just for verifying match between mnn and metallib\n"
"}\n"
"kernel void copy_byte(const device uchar *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void copy_float(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void upcast_float(const device M *in [[buffer(0)]],\n"
" device float *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"kernel void downcast_float(const device float *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=in[int(gid)];\n"
"}\n"
"struct Limit {\n"
" uint4 size;\n"
"};\n"
@ -616,8 +586,8 @@ const char* shader_MetalBackend_metal =
"}\n"
"template <typename IType,typename OType>\n"
"static inline void template_NHWC_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int z=gid.y/s.batch;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
@ -653,8 +623,8 @@ const char* shader_MetalBackend_metal =
"}\n"
"template <typename IType,typename OType>\n"
"static inline void template_NC4HW4_to_NHWC(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int z=gid.y/s.batch;\n"
" int c=z*4;\n"
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
" auto off_out=out+b*s.size*s.channel+int(gid.x)*s.channel+c;\n"
@ -691,8 +661,8 @@ const char* shader_MetalBackend_metal =
"}\n"
"template <typename IType,typename OType>\n"
"static inline void template_NCHW_to_NC4HW4(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int z=gid.y/s.batch;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+(b*s.channel+c)*s.size+int(gid.x);\n"
@ -728,8 +698,8 @@ const char* shader_MetalBackend_metal =
"}\n"
"template <typename IType,typename OType>\n"
"static inline void template_NC4HW4_to_NCHW(const device IType *in,device OType *out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int z=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int z=gid.y/s.batch;\n"
" int c=z*4;\n"
" \n"
" auto off_in=in+int(gid.y)*s.size+int(gid.x);\n"
@ -767,8 +737,8 @@ const char* shader_MetalBackend_metal =
"template<typename IType,typename OType>\n"
"static inline void template_NHWC_to_NCHW(const device IType* in,\n"
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int c4=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int c4=gid.y/s.batch;\n"
" \n"
" auto in_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
" auto out_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
@ -793,8 +763,8 @@ const char* shader_MetalBackend_metal =
"template<typename IType,typename OType>\n"
"static inline void template_NCHW_to_NHWC(const device IType* in,\n"
" device OType* out,constant tensor_shape &s,uint2 gid) {\n"
" int b=gid.y/s.slice;\n"
" int c4=gid.y % s.slice;\n"
" int b=gid.y % s.batch;\n"
" int c4=gid.y/s.batch;\n"
" \n"
" auto in_off=(b*s.channel+c4*4)*s.size+gid.x;\n"
" auto out_off=(b*s.size+gid.x)*s.channel+c4*4;\n"
@ -1452,7 +1422,7 @@ const char* shader_MetalScale_metal =
" const device float4 *biasTerms [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n"
" int z=gid.y % s.steps;\n"
" int z=gid.y/s.batch;\n"
" out[int(gid.y)*s.size+int(gid.x)] =\n"
" in [int(gid.y)*s.size+int(gid.x)]*M4(scales[z])+M4(biasTerms[z]);\n"
"}\n"
@ -1494,8 +1464,8 @@ const char* shader_MetalDeconvolution_metal =
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
" int b=gid.z/cst.output_slice;\n"
" int o=gid.z % cst.output_slice;\n"
" int b=gid.z % cst.batch;\n"
" int o=gid.z/cst.batch;\n"
" FLOAT4 result=FLOAT4(biasTerms[o]);\n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
@ -1513,12 +1483,12 @@ const char* shader_MetalDeconvolution_metal =
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
" auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n"
" auto b_in=in+b*cst.input_slice*cst.input_size;\n"
" auto b_in=in+b*cst.input_size;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n"
" auto in4=b_in[z*cst.input_size+iy*cst.input_width+ix];\n"
" auto in4=b_in[z*cst.input_size*cst.batch+iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
@ -1534,7 +1504,7 @@ const char* shader_MetalDeconvolution_metal =
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z % cst.input_slice)]);\n"
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z/cst.batch)]);\n"
" \n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
@ -1631,10 +1601,11 @@ const char* shader_MetalROIPooling_metal =
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_batch;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int slices;\n"
" int batch;\n"
" float spatial_scale;\n"
"};\n"
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n"
@ -1644,10 +1615,10 @@ const char* shader_MetalROIPooling_metal =
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n"
" \n"
" int ob=gid.z/s.slices;\n"
" int iz=gid.z % s.slices;\n"
" int ob=gid.z % s.batch;\n"
" int iz=gid.z/s.batch;\n"
" \n"
" auto b_roi=roi+ob*8; // roundup(5,4)=8\n"
" auto b_roi=roi+ob*5;\n"
" int ib=int(b_roi[0]);\n"
" int x1=round(float(b_roi[1])*s.spatial_scale);\n"
" int y1=round(float(b_roi[2])*s.spatial_scale);\n"
@ -1656,8 +1627,8 @@ const char* shader_MetalROIPooling_metal =
" \n"
" int roi_w=max(x2-x1+1,1);\n"
" int roi_h=max(y2-y1+1,1);\n"
" auto bin_size_w=(M)roi_w/s.output_width;\n"
" auto bin_size_h=(M)roi_h/s.output_height;\n"
" float bin_size_w=(float)roi_w/(float)s.output_width;\n"
" float bin_size_h=(float)roi_h/(float)s.output_height;\n"
" \n"
" int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n"
" int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n"
@ -1665,7 +1636,7 @@ const char* shader_MetalROIPooling_metal =
" int h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n"
" \n"
" int is_empty=(h_end <= h_start) || (w_end <= w_start);\n"
" auto z_in=in+(ib*s.slices+iz)*s.input_size;\n"
" auto z_in=in+(ib+iz*s.input_batch)*s.input_size;\n"
" auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n"
" for (int y=h_start; y<h_end; y++) {\n"
" auto y_in=z_in+y*s.input_width;\n"
@ -1690,30 +1661,6 @@ const char* shader_MetalConvolution1x1_metal =
" int batch;\n"
" conv_activation_type activation;\n"
"};\n"
"kernel void conv1x1_w1h1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
"}\n"
"kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
@ -1725,8 +1672,8 @@ const char* shader_MetalConvolution1x1_metal =
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
@ -1741,7 +1688,7 @@ const char* shader_MetalConvolution1x1_metal =
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" \n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
@ -1755,19 +1702,18 @@ const char* shader_MetalConvolution1x1_metal =
" const device MNN::char4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device float4 *dequantScale [[buffer(5)]],\n"
" const device float4 *dequantBias [[buffer(6)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" auto scale=FLOAT4(dequantScale[uz]);\n"
" auto dequant_bias=FLOAT4(dequantBias[uz]);\n"
" auto dequant_bias=FLOAT4(dequantScale[uz+cst.output_slice]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=(FLOAT4)*(xy_in0+1);\n"
@ -1779,20 +1725,14 @@ const char* shader_MetalConvolution1x1_metal =
" FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" FLOAT4 w4=w_fp32[i];\n"
" FLOAT4 res;\n"
" for (int j=0; j<4; ++j) {\n"
" float wf=w4[j]*scale[i]+dequant_bias[i];\n"
" res[j]=wf;\n"
" }\n"
" w_dequant[i]=res;\n"
" w_dequant[i]=w_fp32[i]*scale[i]+dequant_bias[i];\n"
" }\n"
" \n"
" result0 += FLOAT4(in40*w_dequant);\n"
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" \n"
" /* true */ \n"
@ -1807,19 +1747,18 @@ const char* shader_MetalConvolution1x1_metal =
" const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device float4 *dequantScale [[buffer(5)]],\n"
" const device float4 *dequantBias [[buffer(6)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" auto scale=FLOAT4(dequantScale[uz]);\n"
" auto dequant_bias=FLOAT4(dequantBias[uz]);\n"
" auto dequant_bias=FLOAT4(dequantScale[uz+cst.output_slice]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=(FLOAT4)*(xy_in0+1);\n"
@ -1833,11 +1772,7 @@ const char* shader_MetalConvolution1x1_metal =
" for (int i=0; i<4; ++i) {\n"
" // M4 w4=M4(w_fp32[i]);\n"
" FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
" FLOAT4 res;\n"
" for (int j=0; j<4; ++j) {\n"
" float wf=w4[j]*scale[i]+dequant_bias[i];\n"
" res[j]=wf;\n"
" }\n"
" FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
" w_dequant[i]=res;\n"
" }\n"
" \n"
@ -1845,7 +1780,7 @@ const char* shader_MetalConvolution1x1_metal =
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" \n"
" /* true */ \n"
@ -1874,8 +1809,8 @@ const char* shader_MetalConvolution1x1_metal =
" int rx=gid.x*CONV_UNROLL_L;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_slice*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_slice*cst.output_size+uz*cst.output_size+rx;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.batch*cst.output_size+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
@ -1898,7 +1833,7 @@ const char* shader_MetalConvolution1x1_metal =
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
@ -1909,71 +1844,20 @@ const char* shader_MetalConvolution1x1_metal =
" if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n"
" if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n"
"}\n"
"kernel void conv1x1_w4h2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in42=xy_in0[2];\n"
" auto in43=xy_in0[3];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto in46=xy_in0[cst.output_width+2];\n"
" auto in47=xy_in0[cst.output_width+3];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result7),cst.activation); }\n"
" }\n"
"}\n"
"kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*4 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y << 2;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" if ((int)gid.x*16 >= cst.output_width || (int)gid.y >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 4;\n"
" int idx_h=0;\n"
" int idx_c=gid.y/cst.batch;\n"
" int idx_b=gid.y % cst.batch;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n"
" FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n"
@ -1984,19 +1868,19 @@ const char* shader_MetalConvolution1x1_metal =
" auto in01=xy_in0[1];\n"
" auto in02=xy_in0[2];\n"
" auto in03=xy_in0[3];\n"
" auto in10=xy_in0[cst.output_width+0];\n"
" auto in11=xy_in0[cst.output_width+1];\n"
" auto in12=xy_in0[cst.output_width+2];\n"
" auto in13=xy_in0[cst.output_width+3];\n"
" auto in10=xy_in0[4];\n"
" auto in11=xy_in0[5];\n"
" auto in12=xy_in0[6];\n"
" auto in13=xy_in0[7];\n"
" \n"
" auto in20=xy_in0[cst.output_width+cst.output_width+0];\n"
" auto in21=xy_in0[cst.output_width+cst.output_width+1];\n"
" auto in22=xy_in0[cst.output_width+cst.output_width+2];\n"
" auto in23=xy_in0[cst.output_width+cst.output_width+3];\n"
" auto in30=xy_in0[cst.output_width+cst.output_width+cst.output_width+0];\n"
" auto in31=xy_in0[cst.output_width+cst.output_width+cst.output_width+1];\n"
" auto in32=xy_in0[cst.output_width+cst.output_width+cst.output_width+2];\n"
" auto in33=xy_in0[cst.output_width+cst.output_width+cst.output_width+3];\n"
" auto in20=xy_in0[8];\n"
" auto in21=xy_in0[9];\n"
" auto in22=xy_in0[10];\n"
" auto in23=xy_in0[11];\n"
" auto in30=xy_in0[12];\n"
" auto in31=xy_in0[13];\n"
" auto in32=xy_in0[14];\n"
" auto in33=xy_in0[15];\n"
" auto w=xy_wt[z];\n"
" result00 += FLOAT4(in00*w);\n"
" result01 += FLOAT4(in01*w);\n"
@ -2016,33 +1900,25 @@ const char* shader_MetalConvolution1x1_metal =
" result32 += FLOAT4(in32*w);\n"
" result33 += FLOAT4(in33*w);\n"
" \n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" int widthSize=min(cst.output_width-idx_w,16);\n"
" /* true */ *xy_out=activate(M4(result00),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,4);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result10),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result11),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+2]=activate(M4(result12),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+3]=activate(M4(result13),cst.activation); }\n"
" }\n"
" if(heightSize>2) {\n"
" /* true */ {xy_out[cst.output_width+cst.output_width+0]=activate(M4(result20),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+1]=activate(M4(result21),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+2]=activate(M4(result22),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+3]=activate(M4(result23),cst.activation); }\n"
" }\n"
" if(heightSize>3) {\n"
" /* true */ {xy_out[cst.output_width+cst.output_width+cst.output_width+0]=activate(M4(result30),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+cst.output_width+cst.output_width+1]=activate(M4(result31),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_width+cst.output_width+cst.output_width+2]=activate(M4(result32),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_width+cst.output_width+cst.output_width+3]=activate(M4(result33),cst.activation); }\n"
" }\n"
" if (widthSize>4) {xy_out[4]=activate(M4(result10),cst.activation); }\n"
" if (widthSize>5) {xy_out[5]=activate(M4(result11),cst.activation); }\n"
" if (widthSize>6) {xy_out[6]=activate(M4(result12),cst.activation); }\n"
" if (widthSize>7) {xy_out[7]=activate(M4(result13),cst.activation); }\n"
" if (widthSize>8) {xy_out[8]=activate(M4(result20),cst.activation); }\n"
" if (widthSize>9) {xy_out[9]=activate(M4(result21),cst.activation); }\n"
" if (widthSize>10) {xy_out[10]=activate(M4(result22),cst.activation); }\n"
" if (widthSize>11) {xy_out[11]=activate(M4(result23),cst.activation); }\n"
" if (widthSize>12) {xy_out[12]=activate(M4(result30),cst.activation); }\n"
" if (widthSize>13) {xy_out[13]=activate(M4(result31),cst.activation); }\n"
" if (widthSize>14) {xy_out[14]=activate(M4(result32),cst.activation); }\n"
" if (widthSize>15) {xy_out[15]=activate(M4(result33),cst.activation); }\n"
"}\n"
"kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
@ -2050,17 +1926,17 @@ const char* shader_MetalConvolution1x1_metal =
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z*2 >= cst.batch*cst.output_slice) return;\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=(gid.z % channel_pack) << 1;\n"
" int idx_b=gid.z/channel_pack;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" \n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
@ -2074,7 +1950,7 @@ const char* shader_MetalConvolution1x1_metal =
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in40*w1);\n"
" result5 += FLOAT4(in41*w1);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
@ -2082,65 +1958,26 @@ const char* shader_MetalConvolution1x1_metal =
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ {xy_out[cst.output_size+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result5),cst.activation); }\n"
" /* true */ {xy_out[cst.output_size*cst.batch +0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
"kernel void conv1x1_w2h2(const device M4 *in [[buffer(0)]],\n"
"kernel void conv1x1_w4c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=gid.z % cst.output_slice;\n"
" int idx_b=gid.z/cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result0=biasValue,result1=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" xy_in0 += cst.input_size;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
"kernel void conv1x1_w2h2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.output_height || (int)gid.z*2 >= cst.batch*cst.output_slice) return;\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y << 1;\n"
" int idx_c=(gid.z % channel_pack) << 1;\n"
" int idx_b=gid.z/channel_pack;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_slice*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_slice*cst.output_size+idx_c*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
@ -2150,8 +1987,8 @@ const char* shader_MetalConvolution1x1_metal =
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[cst.output_width+0];\n"
" auto in45=xy_in0[cst.output_width+1];\n"
" auto in44=xy_in0[2];\n"
" auto in45=xy_in0[3];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt[cst.input_slice+z];\n"
" result0 += FLOAT4(in40*w0);\n"
@ -2162,27 +1999,20 @@ const char* shader_MetalConvolution1x1_metal =
" result3 += FLOAT4(in41*w1);\n"
" result6 += FLOAT4(in44*w1);\n"
" result7 += FLOAT4(in45*w1);\n"
" xy_in0 += cst.input_size;\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int heightSize=min(cst.output_height-idx_h,2);\n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_width+0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_width+1]=activate(M4(result5),cst.activation); }\n"
" }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result5),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ xy_out[cst.output_size]=activate(M4(result2),cst.activation);\n"
" if (widthSize>1) {xy_out[cst.output_size+1]=activate(M4(result3),cst.activation); }\n"
" \n"
" if(heightSize>1) {\n"
" /* true */ {xy_out[cst.output_size+cst.output_width+0]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size+cst.output_width+1]=activate(M4(result7),cst.activation); }\n"
" }\n"
" /* true */ xy_out[cst.output_size*cst.batch]=activate(M4(result2),cst.activation);\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result3),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_size*cst.batch +2]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_size*cst.batch +3]=activate(M4(result7),cst.activation); }\n"
" }\n"
"}\n"
;
@ -2341,7 +2171,7 @@ const char* shader_MetalPReLU_metal =
" uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.slice) return;\n"
" \n"
" int z=gid.z*s.slice+gid.y;\n"
" int z=gid.z+gid.y*s.batch;\n"
" auto v4=in[z*s.size+int(gid.x)];\n"
" out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n"
"}\n"

View File

@ -1,7 +1,6 @@
#ifndef MNN_METAL_SHADER_AUTO_GENERATE_H
#define MNN_METAL_SHADER_AUTO_GENERATE_H
extern const char* shader_MetalReLU6_metal;
extern const char* shader_MetalReLU_metal;
extern const char* shader_MetalConvolutionDepthwise_metal;
extern const char* shader_MetalConvolutionActivation_metal;
extern const char* shader_MetalConvolution_metal;

View File

@ -27,14 +27,6 @@ typedef enum {
CPUTransparent
} MetalAccess;
typedef struct {
/** warp size */
NSUInteger threadExecutionWidth;
/** max threads per thread group */
NSUInteger maxThreadsPerThreadgroup;
/** run concurrently on z axis or not */
BOOL zAxisProtected;
} MetalBandwidth;
}
@interface MNNMetalContext : NSObject
@ -62,14 +54,6 @@ typedef struct {
- (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size bytes:(const void *)bytes access:(MNN::MetalAccess)access;
/**
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
* @param name pipeline name
* @param encoder command encoder
* @return bandwidth info for function
*/
- (MNN::MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16;
/**
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
* @param name pipeline name
@ -86,16 +70,6 @@ typedef struct {
- (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device;
/**
* @brief dispatch encoder with default settings
* @param encoder command encoder
* @param threads threads size
* @param bandwidth bandwidth
*/
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
threads:(MTLSize)threads
bandwidth:(MNN::MetalBandwidth)bandwidth;
/**
* @brief dispatch encoder with designated threads per threadgroup
* @param encoder command encoder
@ -103,10 +77,6 @@ typedef struct {
* @param threadsPerGroup thread size per group
* @param bandwidth bandwidth
*/
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
threads:(MTLSize)threads
threadsPerGroup:(MTLSize)threadsPerGroup
bandwidth:(MNN::MetalBandwidth)bandwidth;
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name fp16:(BOOL)fp16;
- (id<MTLComputePipelineState>)pipelineWithSourceOption:(NSString *)source name:(NSString *)name options:(MTLCompileOptions *)options;
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) pipeline threads:(MTLSize)threads;

View File

@ -181,25 +181,6 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return result;
}
- (MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16 {
id<MTLComputePipelineState> pipeline = [self pipelineWithName:name fp16:fp16];
MNN_ASSERT(nil != pipeline);
[encoder setComputePipelineState:pipeline];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
if (!name) {
} else if (!encoder.label) {
encoder.label = name;
} else {
NSArray *components = [encoder.label componentsSeparatedByString:@","];
if (![components containsObject:name]) {
components = [components arrayByAddingObject:name];
}
encoder.label = [components componentsJoinedByString:@","];
}
#endif
return {pipeline.threadExecutionWidth, pipeline.maxTotalThreadsPerThreadgroup, NO};
}
- (NSUInteger)timeUsed:(id<MTLCommandBuffer>)buffer {
// Get ns precision time
auto start = mach_absolute_time();
@ -393,37 +374,6 @@ static NSUInteger smallest_log2(NSUInteger integer) {
}
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) bw threads:(MTLSize)t {
if (bw.maxTotalThreadsPerThreadgroup > 64) {
auto res = MTLSizeMake(8, 8, 8);
int reduceNumber = 0;
if (t.depth < 4) {
res.depth = 1;
reduceNumber++;
}
if (t.width < 4) {
res.width = 1;
reduceNumber++;
}
if (t.height < 4) {
res.height = 1;
reduceNumber++;
}
if (reduceNumber == 0) {
return MTLSizeMake(4, 4, 4);
}
if (reduceNumber == 2) {
if (res.width > 1) {
res.width = 64;
}
if (res.height > 1) {
res.height = 64;
}
if (res.depth > 1) {
res.depth = 64;
}
}
return res;
}
auto pwarp = smallest_log2(bw.threadExecutionWidth);
auto px = smallest_log2(t.width), sx = (NSUInteger)ceil(log2(t.width));
auto py = smallest_log2(t.height), sy = (NSUInteger)ceil(log2(t.height));
@ -463,91 +413,6 @@ static NSUInteger smallest_log2(NSUInteger integer) {
}
- (MTLSize)threadsPerGroupWithThreads:(MTLSize)t bandwidth:(MetalBandwidth)bw {
auto pwarp = smallest_log2(bw.threadExecutionWidth);
auto px = smallest_log2(t.width), sx = (NSUInteger)ceil(log2(t.width));
auto py = smallest_log2(t.height), sy = (NSUInteger)ceil(log2(t.height));
// accurately match on x
if (px >= pwarp) {
return {bw.threadExecutionWidth, 1, 1};
}
// accurately match on xy
else if (px + py >= pwarp && sx < pwarp / 2) {
NSUInteger x = pow(2, px);
return {x, bw.threadExecutionWidth / x, 1};
}
// similarly match on x
else if (sx >= pwarp) {
return {bw.threadExecutionWidth, 1, 1};
}
// similarly match on xy
else if (sx + sy >= pwarp) {
NSUInteger x = pow(2, sx);
return {x, bw.threadExecutionWidth / x, 1};
}
// on xyz (for most shaders do not protect gid.z, z axis must be accurately match)
auto pz = smallest_log2(t.depth);
auto sz = bw.zAxisProtected ? ceil(log2(t.depth)) : pz;
if (px + py + pz >= pwarp) {
NSUInteger x = pow(2, px), y = pow(2, py);
return {x, y, bw.threadExecutionWidth / x / y};
} else if (sx + sy + sz >= pwarp) {
NSUInteger x = pow(2, sx), z = pow(2, MIN(sz, pwarp - sx));
return {x, bw.threadExecutionWidth / x / z, z};
} else {
NSUInteger z = pow(2, sz);
return {t.width, t.height, z};
}
}
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
threads:(MTLSize)threads
bandwidth:(MetalBandwidth)bandwidth {
[self dispatchEncoder:encoder
threads:threads
threadsPerGroup:[self threadsPerGroupWithThreads:threads bandwidth:bandwidth]
bandwidth:bandwidth];
}
- (void)dispatchEncoder:(id<MTLComputeCommandEncoder>)encoder
threads:(MTLSize)threads
threadsPerGroup:(MTLSize)threadsPerGroup
bandwidth:(MetalBandwidth)bandwidth {
#if MNN_METAL_DEBUG
if (threads.width == 0 || threads.height == 0 || threads.depth == 0 || threadsPerGroup.width == 0 ||
threadsPerGroup.height == 0 || threadsPerGroup.depth == 0) {
printf("[METAL] dispatch error %td %td %td / %td %td %td\n", threads.width, threads.height, threads.depth,
threadsPerGroup.width, threadsPerGroup.height, threadsPerGroup.depth);
}
#endif
// NSLog(@"dispatch {%td %td %td} with {%td %td %td}",
// threads.width, threads.height, threads.depth,
// threadsPerGroup.width, threadsPerGroup.height, threadsPerGroup.depth);
threadsPerGroup.width = MIN(threadsPerGroup.width, bandwidth.maxThreadsPerThreadgroup);
threadsPerGroup.height = MIN(threadsPerGroup.height, bandwidth.maxThreadsPerThreadgroup);
threadsPerGroup.depth = MIN(threadsPerGroup.depth, bandwidth.maxThreadsPerThreadgroup);
#ifdef MNN_BUILD_FOR_IOS
if (@available(iOS 11.0, *)) {
if ([_device supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily4_v1]) {
[encoder dispatchThreads:threads threadsPerThreadgroup:threadsPerGroup];
return;
}
}
#endif
MTLSize groups = {
UP_DIV(threads.width, threadsPerGroup.width), UP_DIV(threads.height, threadsPerGroup.height),
UP_DIV(threads.depth, threadsPerGroup.depth),
};
MNN_ASSERT(threadsPerGroup.width >= 1);
MNN_ASSERT(threadsPerGroup.height >= 1);
MNN_ASSERT(threadsPerGroup.depth >= 1);
[encoder dispatchThreadgroups:groups threadsPerThreadgroup:threadsPerGroup];
}
#if MNN_METAL_DEBUG
#pragma mark debug
- (void)printTensor:(const Tensor *)tensor {

View File

@ -1,3 +1,11 @@
//
// MetalArgMax.mm
// MNN
//
// Created by MNN on 2023/12/29.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "core/Macro.h"
#import "MetalCast.hpp"
#import "MetalBackend.hpp"

View File

@ -0,0 +1,425 @@
//
// MetalAttention.mm
// MNN
//
// Created by MNN on 2024/04/29.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <set>
#import "core/Macro.h"
#import "MetalCast.hpp"
#import "MetalBackend.hpp"
#import "MNNMetalContext.h"
#include "MNN_generated.h"
#if MNN_METAL_ENABLED
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
static const char* gMatMulDivMask = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct Param {
int query_seq_len;
int key_seq_len;
int head_num;
int head_dim;
float scale;
};
kernel void main0(const device T* input0 [[buffer(0)]],
const device T* input1 [[buffer(1)]],
device T* output [[buffer(2)]],
device T* past_key [[buffer(3)]],
const device int* mask [[buffer(4)]],
constant Param& param [[buffer(5)]],
uint3 gid[[thread_position_in_grid]]) {
const int x = gid.x; // query_seq_len
const int y = gid.y; // head_num
const int z = gid.z; // key_seq_len
if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) {
return;
}
int query_seq_len = param.query_seq_len;
int key_seq_len = param.key_seq_len;
int head_num = param.head_num;
int head_dim = param.head_dim;
const int offset = head_num * head_dim;
const int offset_head = y * head_dim;
const device T* A_offset = input0 + x * offset + offset_head;
device T* Pastkey_offset = past_key + z * offset + offset_head;
float Vscale = (float)param.scale;
#ifdef FOR_PREFILL
device const T* B_offset = input1 + z * offset + offset_head;
const int output_offset = y * query_seq_len * key_seq_len;
float out0 = 0.0;
for(int i = 0; i < head_dim; ++i){
float A = (float)(A_offset[i]);
float B = (float)(B_offset[i]);
out0 += B * A;
Pastkey_offset[i] = (T)B;
}
out0 *= Vscale;
out0 = mask[((x + 0) * key_seq_len + (z + 0))] == 0 ? -FLT_MAX : out0;
output[output_offset + x * key_seq_len + z] = (T)out0;
#else
const device T *B_offset = input1 + offset_head;
float out = 0.0;
if (z == key_seq_len - 1) {
for(int i = 0; i < head_dim; ++i){
float A = (float)(A_offset[i]);
float B = (float)(B_offset[i]);
out += B * A;
Pastkey_offset[i] = (T)B;
}
} else {
for(int i = 0; i < head_dim; ++i){
float A = A_offset[i];
float B = (float)Pastkey_offset[i];
out += A * B;
}
}
out *= Vscale;
output[y + z * head_num] = (T)out;
#endif
}
)metal";
static const char* gMatMulQKV = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct Param {
int query_seq_len;
int key_seq_len;
int head_num;
int head_dim;
float scale;
};
kernel void main0(const device T* input0 [[buffer(0)]],
const device T* input1 [[buffer(1)]],
device T* output [[buffer(2)]],
device T* past_value [[buffer(3)]],
constant Param& param [[buffer(4)]],
uint3 gid[[thread_position_in_grid]]) {
const int x = gid.x; // query_seq_len
const int y = gid.y; // head_num
const int z = gid.z; // head_dim
if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) {
return;
}
int qk_seq_len = param.query_seq_len;
int value_seq_len = param.key_seq_len;
int head_num = param.head_num;
int head_dim = param.head_dim;
const int offset = head_num * head_dim;
const int offset_head = y * head_dim + z;
#ifdef FOR_PREFILL
device const T *A_offset = input0 + (y * qk_seq_len + x) * value_seq_len;
device const T *B_offset = input1 + offset_head;
device T *Pastvalue_offset = past_value + offset_head;
float out = 0.0;
for(int i = 0; i < value_seq_len; ++i){
float A0 = (float)A_offset[i];
float B = (float)B_offset[i*offset];
out += A0 * B;
Pastvalue_offset[i*offset] = B;
}
output[ x * offset + (y * head_dim + z)] = out;
#else
device const T *A_offset = input0 + y;
device const T *B_offset = input1 + offset_head;
device T *Pastvalue_offset = past_value + offset_head;
float out = 0;
for(int i = 0; i < value_seq_len - 1; ++i){
float A = (float)A_offset[i * head_num];
float B = (float)Pastvalue_offset[i * offset];
out += A * B;
}
out += (float)A_offset[(value_seq_len - 1)*head_num] * (float)B_offset[0];
Pastvalue_offset[(value_seq_len - 1)*offset] = B_offset[0];
output[(y * head_dim + z)] = (T)out;
#endif
}
)metal";
namespace MNN {
class AttentionBufExecution : public MetalExecution {
public:
struct SharedCache {
std::shared_ptr<Tensor> mPastKey;
std::shared_ptr<Tensor> mPastValue;
int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0;
};
AttentionBufExecution(Backend *backend, bool kv_cache);
virtual ~AttentionBufExecution() = default;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override {
if (nullptr == dst) {
return true;
}
auto exe = new AttentionBufExecution(bn, mKVCache);
exe->mCache = mCache;
*dst = exe;
return true;
}
private:
void _init();
void reallocKVCache();
bool mKVCache;
std::shared_ptr<SharedCache> mCache;
float mScale;
const int mExpandChunk = 64;
bool mIsDecode = false;
std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
int mNumHead = 0, mHeadDim = 0, mValueH = 0;
id<MTLComputePipelineState> mKernel_softmax;
id<MTLComputePipelineState> mKernel_qk;
id<MTLComputePipelineState> mKernel_qkv;
id<MTLComputePipelineState> mKernelPrefill_qk;
id<MTLComputePipelineState> mKernelPrefill_qkv;
id<MTLBuffer> mParamQKV;
id<MTLBuffer> mParamSoftmax;
};
struct Param {
int query_seq_len;
int key_seq_len;
int head_num;
int head_dim;
float scale;
};
AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cache)
: MetalExecution(backend) , mKVCache(kv_cache){
_init();
}
void AttentionBufExecution::_init() {
mCache.reset(new SharedCache);
auto mtbn = static_cast<MetalBackend *>(backend());
auto context = (__bridge MNNMetalContext *)mtbn->context();
mKernel_softmax = [context pipelineWithName:@"softmax_plane" fp16:mtbn->useFp16InsteadFp32()];
mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly];
mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
auto rt = mtbn->runtime();
std::string T = "float";
if (mtbn->useFp16InsteadFp32()) {
T = "half";
}
std::vector<std::string> keys = {
"matmul_qk_div_mask",
T
};
auto pipeline = rt->findPipeline(keys);
if (nil == pipeline) {
// Rebuild Pipeline
MTLCompileOptions *decodeOption = [[MTLCompileOptions alloc] init];
decodeOption.preprocessorMacros = @{
@"T" : @(T.c_str()),
};
MTLCompileOptions* encodeOption = [[MTLCompileOptions alloc] init];
encodeOption.preprocessorMacros = @{
@"T" : @(T.c_str()),
@"FOR_PREFILL": @"1"
};
mKernel_qk = mtbn->makeComputePipelineWithSourceOption(gMatMulDivMask, "main0", decodeOption);
mKernelPrefill_qk = mtbn->makeComputePipelineWithSourceOption(gMatMulDivMask, "main0", encodeOption);
mKernel_qkv = mtbn->makeComputePipelineWithSourceOption(gMatMulQKV, "main0", decodeOption);
mKernelPrefill_qkv = mtbn->makeComputePipelineWithSourceOption(gMatMulQKV, "main0", encodeOption);
rt->insertPipeline({"matmul_qk_div_mask", T}, mKernel_qk);
rt->insertPipeline({"matmul_qk_div_mask", T, "PREFILL"}, mKernelPrefill_qk);
rt->insertPipeline({"matmul_qkv", T}, mKernel_qkv);
rt->insertPipeline({"matmul_qkv", T, "PREFILL"}, mKernelPrefill_qkv);
} else {
mKernel_qk = rt->findPipeline({"matmul_qk_div_mask", T});
mKernelPrefill_qk = rt->findPipeline({"matmul_qk_div_mask", T, "PREFILL"});
mKernel_qkv = rt->findPipeline({"matmul_qkv", T});
mKernelPrefill_qkv = rt->findPipeline({"matmul_qkv", T, "PREFILL"});
}
MNN_ASSERT(nil != mKernel_qk);
MNN_ASSERT(nil != mKernel_qkv);
MNN_ASSERT(nil != mKernelPrefill_qk);
MNN_ASSERT(nil != mKernelPrefill_qkv);
}
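// Cache policy below: the temporary QK/softmax tensors are re-created whenever the KV cache still
// has room, none exist yet, or the pass is a prefill (their shape depends on the sequence length);
// the persistent past_key/past_value tensors are re-allocated only once mPastLength reaches
// mMaxLength, growing by mExpandChunk entries and copying the old contents into the new buffers.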
void AttentionBufExecution::reallocKVCache() {
if (mCache->mPastLength < mCache->mMaxLength || nullptr == mTempQK || (!mIsDecode)) {
if (mIsDecode) {
mTempQK.reset(Tensor::createDevice<float>({mNumHead, mCache->mMaxLength}));
mTempSoftMax.reset(Tensor::createDevice<float>({mNumHead, mCache->mMaxLength}));
} else {
mTempQK.reset(Tensor::createDevice<float>({mNumHead, mCache->mPastLength, mCache->mPastLength}));
mTempSoftMax.reset(Tensor::createDevice<float>({mNumHead, mCache->mPastLength, mCache->mPastLength}));
}
backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC);
}
if (!mKVCache || mCache->mPastLength < mCache->mMaxLength) {
return;
}
auto mtbn = static_cast<MetalBackend *>(backend());
int byte = 4;
if(mtbn->useFp16InsteadFp32()) {
byte = 2;
}
bool needCopy = mCache->mMaxLength > 0;
size_t old_size = mNumHead * mCache->mMaxLength * mHeadDim * byte;
mCache->mMaxLength = mCache->mPastLength + mExpandChunk;
// past_key: [1, numhead, headdim, maxlen]
auto new_key = Tensor::createDevice<float>({mCache->mMaxLength, mNumHead, mHeadDim});
// past_value: [1, numhead, maxlen, headdim]
auto new_value = Tensor::createDevice<float>({mCache->mMaxLength, mNumHead, mHeadDim});
size_t size = mNumHead * mCache->mMaxLength * mHeadDim * byte;
backend()->onAcquireBuffer(new_key, Backend::STATIC);
backend()->onAcquireBuffer(new_value, Backend::STATIC);
if (needCopy) {
auto newKeyBuf = MetalBackend::getBuffer(new_key);
auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second;
auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get());
auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;
::memcpy(new_key_ptr, key_ptr, old_size);
auto newValueBuf = MetalBackend::getBuffer(new_value);
auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second;
auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get());
auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second;
::memcpy(new_value_ptr, value_ptr, old_size);
}
mCache->mPastKey.reset(new_key);
mCache->mPastValue.reset(new_value);
}
void AttentionBufExecution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto mtbn = static_cast<MetalBackend *>(backend());
auto context = (__bridge MNNMetalContext *)mtbn->context();
auto shape = query->shape();
int seq_len = shape[1];
mNumHead = shape[2];
mHeadDim = shape[3];
mScale = 1.0 / sqrt(mHeadDim);
mIsDecode = seq_len == 1;
if (mCache->mPastLength == 0 || seq_len > 1) {
mCache->mPastLength = seq_len;
}
mCache->mKv_seq_len = mCache->mPastLength;
if(mIsDecode){
mCache->mKv_seq_len = mCache->mPastLength + 1;
}
reallocKVCache();
// Update Parameters
{
auto param = (Param*)mParamQKV.contents;
param->scale = mScale;
param->head_dim = mHeadDim;
param->key_seq_len = mCache->mKv_seq_len;
param->head_num = mNumHead;
param->query_seq_len = seq_len;
}
// For softmax parameter
int inside, outside;
if (mIsDecode) {
inside = mNumHead;
outside = 1;
} else {
inside = 1;
outside = mCache->mKv_seq_len * mNumHead;
}
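// Score layouts produced by the QK kernels: prefill stores rows of length key_seq_len for each
// (head, query) pair, decode stores [key_seq_len, head_num] with a stride of head_num along the
// key axis; the softmax axis is the key dimension in both cases, hence the swapped inside/outside.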
int axis = mCache->mKv_seq_len;
{
auto softmax = (int*)mParamSoftmax.contents;
// Inside, axis, outside, plane(invalid)
softmax[0] = inside;
softmax[1] = axis;
softmax[2] = outside;
softmax[3] = 0;
}
// Run QK Kernel
{
id<MTLComputePipelineState> pipeline;
if (mIsDecode) {
pipeline = mKernel_qk;
} else {
pipeline = mKernelPrefill_qk;
}
[encoder setComputePipelineState:pipeline];
MetalBackend::setTensor(query, encoder, 0);
MetalBackend::setTensor(key, encoder, 1);
MetalBackend::setTensor(mTempQK.get(), encoder, 2);
MetalBackend::setTensor(mCache->mPastKey.get(), encoder, 3);
MetalBackend::setTensor(mask, encoder, 4);
[encoder setBuffer:mParamQKV offset:0 atIndex:5];
auto gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mCache->mKv_seq_len)];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
// Run Softmax Kernel
{
[encoder setComputePipelineState:mKernel_softmax];
MetalBackend::setTensor(mTempQK.get(), encoder, 0);
MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1);
[encoder setBuffer:mParamSoftmax offset:0 atIndex:2];
auto gl = [context computeBestGroupAndLocal: mKernel_softmax threads:MTLSizeMake(inside, outside, 1)];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
// Run QKV Kernel
{
id<MTLComputePipelineState> pipeline;
if (mIsDecode) {
pipeline = mKernel_qkv;
} else {
pipeline = mKernelPrefill_qkv;
}
[encoder setComputePipelineState:pipeline];
MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0);
MetalBackend::setTensor(value, encoder, 1);
MetalBackend::setTensor(outputs[0], encoder, 2);
MetalBackend::setTensor(mCache->mPastValue.get(), encoder, 3);
[encoder setBuffer:mParamQKV offset:0 atIndex:4];
auto gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mHeadDim)];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
// Update status
if(mIsDecode){
mCache->mPastLength += 1;
mCache->mKv_seq_len = mCache->mPastLength + 1;
}
return;
}
class AttentionBufCreator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *> &outputs) const override {
auto param = op->main_as_AttentionParam();
return new AttentionBufExecution(backend, param->kv_cache());
}
};
REGISTER_METAL_OP_TRANSFORMER_CREATOR(AttentionBufCreator, OpType_Attention);
} // namespace MNN
#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */
#endif

View File

@ -137,7 +137,7 @@ public:
* @param creator registering creator.
*/
static void addCreator(OpType type, Creator *creator);
static void setTensor(MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index);
static void setTensor(const MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index);
static std::pair<id<MTLBuffer>, int> getBuffer(MNN::Tensor* tensor);
size_t getTensorSizeInBytes(const Tensor* tensor) const;
virtual bool onSelectDynamicAllocator(int index, int maxIndex) override;
@ -264,5 +264,10 @@ public:
MetalBackend::addCreator(opType, new name); \
}
#define REGISTER_METAL_OP_TRANSFORMER_CREATOR(name, opType) \
void ___##name##__##opType##__() { \
MetalBackend::addCreator(opType, new name); \
}
#endif /* MNN_METAL_ENABLED */
#endif /* MetalBackend_hpp */

View File

@ -5,12 +5,12 @@
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalBackend.hpp"
#define MNN_METAL
#import <MNN/MNNSharedContext.h>
#define METAL_CONST_BUFFER_LIMIT 128
#if MNN_METAL_ENABLED
#import <mutex>
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "core/TensorUtils.hpp"
@ -126,24 +126,30 @@ size_t MetalBackend::getTensorSizeInBytes(const Tensor* tensor) const {
auto format = TensorUtils::getDescribe(tensor)->dimensionFormat;
size_t size;
if (MNN_DATA_FORMAT_NC4HW4 == format && tensor->dimensions() >= 2) {
size_t width = 1;
size_t height = 1;
auto batch = tensor->length(0);
auto channel = tensor->length(1);
int width = 1;
int height = 1;
int batch = tensor->length(0);
int channel = tensor->length(1);
if (tensor->dimensions() >= 3) {
height = tensor->length(2);
}
for (int i=3; i<tensor->dimensions(); ++i) {
width *= tensor->length(i);
}
auto alignC = ROUND_UP(channel, 4);
auto hR = ROUND_UP(height, 4) - height;
int alignC = ROUND_UP(channel, 4);
int hR = ROUND_UP(height, 4) - height;
// width parallel 4, may exceed 3 elements
auto wR = ROUND_UP(width + 3, 4) - width;
int wR = ROUND_UP(width + 3, 4) - width;
int bhw = batch * width * height;
int bhwR = UP_DIV(bhw, 16) * 16 - bhw;
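// The conv1x1 kernels walk the fused batch*height*width run in tiles of up to 16 float4 elements,
// so reserve whichever padding is larger: the remainder of that fused run rounded up to 16, or the
// per-row height/width over-read padding computed above.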
int extraPadding = ALIMAX(bhwR, (hR * width + wR));
size = batch * alignC * width * height;
size = size + hR * width * 4 + wR * 4;
size = size + extraPadding * 4;
} else {
size = tensor->elementSize();
size = 1;
for (int i=0; i<tensor->dimensions(); ++i) {
size *= tensor->length(i);
}
size = ROUND_UP(size, 4);
}
if (0 == size) {
@ -356,7 +362,7 @@ MTLSize getTensorShape(id<MTLBuffer> shape, const Tensor *tensor) {
// shape
((int *)shape.contents)[0] = s;
((int *)shape.contents)[1] = c;
((int *)shape.contents)[2] = z;
((int *)shape.contents)[2] = b;
((int *)shape.contents)[3] = b * z;
// threads
@ -459,13 +465,6 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *) (dst->deviceId()))->getBuffer();
auto floats = src->getType().code == halide_type_float;
std::unique_ptr<Tensor> tempSrc;
if (sfmt == dfmt && sfmt == MNN_DATA_FORMAT_NC4HW4) {
tempSrc.reset(new Tensor(src, Tensor::CAFFE));
MNNCPUCopyBuffer(src, tempSrc.get());
src = tempSrc.get();
sfmt = TensorUtils::getDescribe(src)->dimensionFormat;
}
// For command queue from user, need user to make sure last frame's gpu work is ready
bool needWait = !mRuntime->userSync();
// cast
@ -486,7 +485,8 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
};
::memcpy(mShapeH2D.contents, limits, sizeof(limits));
auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto bandwidth = [ctx load: @"downcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
auto pipeline = [ctx pipelineWithName:@"downcast_float4" fp16:mUseFloatAsFp16];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:host offset:0 atIndex:0];
[encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
@ -494,7 +494,7 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
//[ctx dispatchEncoder:encoder threads:{sizeC4, 1, 1} bandwidth:bandwidth];
std::pair<MTLSize, MTLSize> threads;
threads.first = {sizeC4, 1, 1};
threads.second = {bandwidth.maxThreadsPerThreadgroup, 1, 1};
threads.second = {[pipeline maxTotalThreadsPerThreadgroup], 1, 1};
threads.second.width = threads.second.width <= threads.first.width ? threads.second.width : threads.first.width;
threads.first.width = UP_DIV(threads.first.width, threads.second.width);
[encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];
@ -520,13 +520,14 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Down);
MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt
auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:buffer offset:0 atIndex:0];
[encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[encoder setBuffer:mShapeH2D offset:0 atIndex:2];
[ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
[encoder endEncoding];
commit();
//[ctx wait];
@ -539,16 +540,6 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer();
auto floats = src->getType().code == halide_type_float;
std::shared_ptr<Tensor> tempDst;
if (sfmt == dfmt && sfmt == MNN_DATA_FORMAT_NC4HW4) {
tempDst.reset(new Tensor(dst, Tensor::CAFFE), [dst](void* t) {
auto tensor = (Tensor*)t;
MNNCPUCopyBuffer(tensor, dst);
delete tensor;
});
dst = tempDst.get();
dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
}
// cast
if (sfmt == dfmt || src->dimensions() <= 1) {
if (floats && mUseFloatAsFp16) {
@ -558,7 +549,8 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
NSUInteger size = src->elementSize();
auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto bandwidth = [ctx load: @"upcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
auto pipeline = [ctx pipelineWithName:@"upcast_float4" fp16:mUseFloatAsFp16];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:buffer offset:0 atIndex:1];
auto sizeC4 = UP_DIV(size, 4);
@ -573,7 +565,7 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
//[ctx dispatchEncoder:encoder threads:{sizeC4, 1, 1} bandwidth:bandwidth];
std::pair<MTLSize, MTLSize> threads;
threads.first = {sizeC4, 1, 1};
threads.second = {bandwidth.maxThreadsPerThreadgroup, 1, 1};
threads.second = {[pipeline maxTotalThreadsPerThreadgroup], 1, 1};
threads.second.width = threads.second.width <= threads.first.width ? threads.second.width : threads.first.width;
threads.first.width = UP_DIV(threads.first.width, threads.second.width);
[encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];
@ -597,18 +589,28 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Up);
MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt
auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:buffer offset:0 atIndex:1];
[encoder setBuffer:mShapeD2H offset:0 atIndex:2];
[ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
[encoder endEncoding];
commit();
wait();
memcpy(dst->host<float>(), buffer.contents, dst->size());
}
}
static const char* gCopy = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
kernel void main0(const device int4 *in [[buffer(0)]], device int4 *out [[buffer(1)]], constant uint4& limit [[buffer(2)]], uint gid [[thread_position_in_grid]]) {
if (gid < limit.x) {
out[int(gid)] = in[int(gid)];
}
})metal";
void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const {
auto ctx = (__bridge MNNMetalContext *)context();
@ -619,12 +621,28 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
// copy
if (sfmt == dfmt || src->dimensions() <= 1) {
auto flt = dst->getType().code == halide_type_float;
auto size = flt ? dst->elementSize() : dst->size();
auto bandwidth = [ctx load:flt ? @"copy_float" : @"copy_byte" encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[ctx dispatchEncoder:encoder threads:{(NSUInteger)size, 1, 1} bandwidth:bandwidth];
auto size = dst->usize();
if (mUseFloatAsFp16 && dst->getType().code == halide_type_float) {
size = size / 2;
}
size = UP_DIV(size, (4 * sizeof(float)));
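// From here on 'size' counts 16-byte (int4) chunks: usize() gives bytes for the host element type,
// the fp16 path halves it because the device buffer stores half precision, and the copy kernel
// below moves one int4 per thread, guarded by limit.x.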
std::vector<std::string> keys = {
"copyC4"
};
id<MTLComputePipelineState> pipeline = mRuntime->findPipeline(keys);
if (nil == pipeline) {
pipeline = makeComputePipelineWithSourceOption(gCopy, "main0", nil);
mRuntime->insertPipeline(keys, pipeline);
}
[encoder setComputePipelineState:pipeline];
if (shape == nil) {
shape = getConstBuffer(4 * sizeof(int));
}
((uint32_t*)[shape contents])[0] = size;
setTensor(src, encoder, 0);
setTensor(dst, encoder, 1);
[encoder setBuffer:shape offset:0 atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
}
// convert
else {
@ -635,11 +653,13 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
}
auto size = getTensorShape(shape, src);
auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
auto pipeline = [ctx pipelineWithName:kernel fp16:mUseFloatAsFp16];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(src->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(dst->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[encoder setBuffer:shape offset:0 atIndex:2];
[ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
if (standalone) {
@ -712,7 +732,7 @@ id<MTLCommandBuffer> MetalBackend::getCommandBufferForNet() const {
return _commandBuffer_net;
}
void MetalBackend::setTensor(MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index) {
void MetalBackend::setTensor(const MNN::Tensor* tensor, id<MTLComputeCommandEncoder> encoder, int index) {
[encoder setBuffer:((MetalRuntimeAllocator::MetalBufferAlloc *)tensor->deviceId())->getBuffer() offset:TensorUtils::getDescribe(tensor)->extra.offset atIndex:index];
}
std::pair<id<MTLBuffer>, int> MetalBackend::getBuffer(MNN::Tensor* tensor) {

View File

@ -18,6 +18,7 @@ def genRegister():
f.write(" namespace MNN {\n")
f.write("#if MNN_METAL_ENABLED\n")
funcs=[]
transformerFuncs = []
for shapath in shaders:
with open(shapath,"r") as sha:
lines=sha.readlines()
@ -27,10 +28,22 @@ def genRegister():
funcname="___"+x[0]+"__"+x[1]+"__();"
funcs.append(funcname)
f.write(" extern void "+funcname+"\n")
elif l.startswith('REGISTER_METAL_OP_TRANSFORMER_CREATOR('):
x=l.replace("REGISTER_METAL_OP_TRANSFORMER_CREATOR(","").replace(")","").replace(" ","").replace(";","").replace("\n","").split(",")
funcname="___"+x[0]+"__"+x[1]+"__();"
transformerFuncs.append(funcname)
f.write("#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n")
f.write(" extern void "+funcname+"\n")
f.write('#endif\n')
pass
f.write("void registerMetalOps() {\n")
for func in funcs:
f.write(" "+func+"\n")
f.write('#ifdef MNN_SUPPORT_TRANSFORMER_FUSE\n')
for func in transformerFuncs:
f.write(" "+func+"\n")
f.write('#endif\n')
f.write("}\n#endif\n}")
if os.path.isdir(renderPath):
shaders=[]
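With the new branch, the registration file this script emits is expected to take roughly the following shape; the convolution entry is a placeholder name, while the attention entry corresponds to the REGISTER_METAL_OP_TRANSFORMER_CREATOR call added in MetalAttention.mm above:

// Illustrative sketch of the generated output, not the actual file.
namespace MNN {
#if MNN_METAL_ENABLED
 extern void ___MetalConvolutionCreator__OpType_Convolution__();   // placeholder example name
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
 extern void ___AttentionBufCreator__OpType_Attention__();
#endif
void registerMetalOps() {
 ___MetalConvolutionCreator__OpType_Convolution__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
 ___AttentionBufCreator__OpType_Attention__();
#endif
}
#endif
}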

View File

@ -22,7 +22,7 @@ public:
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
protected:
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
MetalConvolution(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias);

View File

@ -175,14 +175,13 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
return NO_ERROR;
}
void MetalConvolution::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
auto oc_4 = UP_DIV(output->channel(), 4);
void MetalConvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0];
auto output = outputs[0];
auto bandwidth = (MetalBandwidth){mPipeline.threadExecutionWidth, mPipeline.maxTotalThreadsPerThreadgroup, NO};
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
MetalBackend::setTensor(input, encoder, 0);
MetalBackend::setTensor(output, encoder, 1);
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
MetalBackend::setTensor(mWeight.get(), encoder, 3);
MetalBackend::setTensor(mBias.get(), encoder, 4);

View File

@ -21,10 +21,9 @@ public:
virtual ~MetalConvolution1x1() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
protected:
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, std::shared_ptr<MNN::Tensor> dequantBias, int dequantBits);
MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits);
id<MTLComputePipelineState> mPipeline;
std::pair<MTLSize, MTLSize> mThreads;
};

View File

@ -31,11 +31,10 @@ MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op) :
loadWeight(op->main_as_Convolution2D(), ldInt8Weight);
}
MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, std::shared_ptr<MNN::Tensor> dequantBias, int dequantBits) : MetalConvolutionCommon(backend, op, bias) {
MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> weight, std::shared_ptr<MNN::Tensor> bias, std::shared_ptr<MNN::Tensor> dequantScale, int dequantBits) : MetalConvolutionCommon(backend, op, bias) {
mWeight = weight;
mBias = bias;
mDequantScale = dequantScale;
mDequantZero = dequantBias;
mDequantScaleBias = dequantScale;
mDequantBits = dequantBits;
}
@ -47,7 +46,7 @@ bool MetalConvolution1x1::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScale, mDequantZero, mDequantBits);
*dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits);
return true;
}
@ -55,15 +54,18 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
MetalConvolutionCommon::onResize(inputs, outputs);
// prepare
// For C4NHW4 format, NHW can be fuse to W
auto input = inputs[0];
auto output = outputs[0];
auto is = input->width() * input->height();
auto ic_4 = UP_DIV(input->channel(), 4);
auto ow = output->width();
auto oh = output->height();
auto os = ow * oh;
auto ob = output->batch();
int is = input->batch();
for (int i=2; i<input->dimensions(); ++i) {
is *= input->length(i);
}
int ic_4 = UP_DIV(input->channel(), 4);
int ow = is;
int oh = 1;
int os = ow;
int ob = 1;
auto oc = output->channel();
auto oc_4 = UP_DIV(output->channel(), 4);
auto backend = static_cast<MetalBackend *>(this->backend());
@ -75,7 +77,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
::memcpy(mConstBuffer.contents, constants, sizeof(constants));
MetalRuntime* rt = (MetalRuntime *)backend->runtime();
if (ow == input->width() && oh == input->height() && mDequantScale.get()) {
if (mDequantScaleBias.get()) {
NSUInteger gid_x = UP_DIV(ow * oh, 4);
NSUInteger gid_y = oc_4;
NSUInteger gid_z = ob;
@ -89,8 +91,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
(id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(),
mConstBuffer, (((MetalRuntimeAllocator::MetalBufferAlloc *)mWeight->deviceId()))->getBuffer(),
((MetalRuntimeAllocator::MetalBufferAlloc *)mBias->deviceId())->getBuffer(),
(((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantScale->deviceId()))->getBuffer(),
(((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantZero->deviceId()))->getBuffer(), nil];
(((MetalRuntimeAllocator::MetalBufferAlloc *)mDequantScaleBias->deviceId()))->getBuffer(),
nil];
const Tensor* weight = mWeight.get();
const Tensor* bias = mBias.get();
int buffer_offset[] = {
@ -99,8 +101,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
0,
TensorUtils::getDescribe(weight)->extra.offset,
TensorUtils::getDescribe(bias)->extra.offset,
TensorUtils::getDescribe(mDequantScale.get())->extra.offset,
TensorUtils::getDescribe(mDequantZero.get())->extra.offset,
TensorUtils::getDescribe(mDequantScaleBias.get())->extra.offset,
0};
MetalRuntime *rt = (MetalRuntime *)backend->runtime();
@ -147,9 +148,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
//printf("conv1x1_z4, %d %d %d %d\n", ow, oh, oc_4, ic_4);
}
} else {
NSString* shaderName[] = {@"conv1x1_w4h2", @"conv1x1_w2h2", @"conv1x1_w4h4", @"conv1x1_w2c2", @"conv1x1_w2h2c2"};
int itemW[] = {4, 2, 4, 2, 2};
int itemH[] = {2, 2, 4, 1, 2};
NSString* shaderName[] = {@"conv1x1_g1z8", @"conv1x1_g1z4", @"conv1x1_w4h4", @"conv1x1_w2c2", @"conv1x1_w4c2"};
int itemW[] = {8, 4, 16, 2, 4};
int itemC[] = {4, 4, 4, 8, 8};
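// Candidate tiling table: itemW is the number of output positions along the fused batch*H*W axis
// each thread produces, itemC the number of output channels per thread; the timed autotune below
// keeps the fastest of the five kernels.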
int actual_kernel = 5;
if (oc_4 % 2 != 0) {
@ -168,8 +168,8 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx] fp16:backend->useFp16InsteadFp32()];
NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]);
NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]);
NSUInteger gid_z = ob * UP_DIV(oc, itemC[knl_idx]);
NSUInteger gid_y = UP_DIV(oc, itemC[knl_idx]);
NSUInteger gid_z = 1;
std::string name = [shaderName[knl_idx] UTF8String];
auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name offsets:buffer_offset queue:backend->queue()];
@ -188,16 +188,17 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
return NO_ERROR;
}
void MetalConvolution1x1::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolution1x1::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0];
auto output = outputs[0];
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
MetalBackend::setTensor(mWeight.get(), encoder, 3);
MetalBackend::setTensor(mBias.get(), encoder, 4);
if (mDequantScale && mDequantZero) {
MetalBackend::setTensor(mDequantScale.get(), encoder, 5);
MetalBackend::setTensor(mDequantZero.get(), encoder, 6);
if (mDequantScaleBias) {
MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 5);
}
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

View File

@ -20,12 +20,10 @@ class MetalConvolutionCommon : public MetalExecution {
public:
MetalConvolutionCommon(Backend *backend, const MNN::Op *op, std::shared_ptr<MNN::Tensor> bias);
virtual ~MetalConvolutionCommon() = default;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
protected:
void loadWeight(const MNN::Convolution2D *conv, bool loadWeightInt8 = false);
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) = 0;
virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight = false, bool int4Weight = false);
private:
@ -42,8 +40,7 @@ protected:
std::shared_ptr<MNN::Tensor> mWeight;
std::shared_ptr<MNN::Tensor> mBias;
std::shared_ptr<MNN::Tensor> mDequantScale;
std::shared_ptr<MNN::Tensor> mDequantZero;
std::shared_ptr<MNN::Tensor> mDequantScaleBias;
int mDequantBits;
id<MTLBuffer> mConstBuffer = nil;
};

View File

@ -68,10 +68,6 @@ MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *
}
}
void MetalConvolutionCommon::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
return onFloat(inputs[0], outputs[0], encoder);
}
template <typename FType, typename TType>
void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src, uint8_t* dstOrigion) {
auto goc = oc / group;
@ -101,25 +97,24 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src,
}
}
static std::vector<std::shared_ptr<MNN::Tensor>> getDequantScale(float* scale, int size, MetalBackend *backend, bool asymmetric) {
static std::shared_ptr<MNN::Tensor> getDequantScale(float* scale, int size, MetalBackend *backend, bool asymmetric) {
int outputCount = 0;
if (asymmetric) {
outputCount = size / 2;
} else {
outputCount = size;
}
std::vector<std::shared_ptr<MNN::Tensor>> scaleBias(2);
std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({outputCount * 4}));
std::shared_ptr<MNN::Tensor> dequantBias(MNN::Tensor::createDevice<uint8_t>({outputCount * 4}));
bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC) && backend->onAcquireBuffer(dequantBias.get(), Backend::STATIC);
int alignOutputCount = ALIGN_UP4(outputCount);
std::shared_ptr<MNN::Tensor> dequantScale(MNN::Tensor::createDevice<uint8_t>({(int)(alignOutputCount * sizeof(float) * 2)}));
bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC);
if (!res) {
MNN_ERROR("Buffer allocated error!\n");
return scaleBias;
return nullptr;
}
auto buffer0 = MetalBackend::getBuffer(dequantScale.get());
auto dst_scale = (uint8_t*)[buffer0.first contents] + buffer0.second;
auto buffer1 = MetalBackend::getBuffer(dequantBias.get());
auto dst_bias = (uint8_t*)[buffer1.first contents] + buffer1.second;
::memset(dst_scale, 0, alignOutputCount * 2 * sizeof(float));
auto dst_bias = dst_scale + alignOutputCount * sizeof(float);
for (int o = 0; o < outputCount; ++o) {
float min = 0.0f;
float alpha = 0.0f;
@ -131,10 +126,8 @@ static std::vector<std::shared_ptr<MNN::Tensor>> getDequantScale(float* scale, i
}
((float*)dst_scale)[o] = alpha;
((float*)dst_bias)[o] = min;
}
scaleBias[0] = dequantScale;
scaleBias[1] = dequantBias;
return scaleBias;
}
return dequantScale;
}
void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv, bool loadWeightInt8) {
std::shared_ptr<ConvolutionCommon::Int8Common> qnt = NULL;
@ -157,8 +150,7 @@ void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv, bool loa
auto backend = static_cast<MetalBackend *>(this->backend());
mWeight = weightTransform(group, oc, ic, kh, kw, (float*)qnt->weight.get(), !qnt->canUseInt4, qnt->canUseInt4);
auto dequantParams = getDequantScale(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric);
mDequantScale = dequantParams[0];
mDequantZero = dequantParams[1];
mDequantScaleBias = dequantParams;
mDequantBits = qnt->canUseInt4 ? 4:8;
} else if (qnt && qnt->weightFloat.size() > 0) {
mWeight = weightTransform(group, oc, ic, kh, kw, qnt->weightFloat.get(), false, false);
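Note: the separate mDequantScale / mDequantZero tensors are merged into a single mDequantScaleBias buffer whose first ALIGN_UP4(outputCount) floats hold the per-channel scale and whose next ALIGN_UP4(outputCount) floats hold the per-channel bias (min); the conv1x1 shaders further down read the bias half at dequantScale[uz + cst.output_slice]. A minimal host-side sketch of that packing, assuming the asymmetric alpha array stores (min, alpha) pairs (function and variable names here are illustrative, not MNN's):

    // Packed layout assumed above: [ scale[0..alignedOC-1] | bias[0..alignedOC-1] ].
    #include <vector>

    std::vector<float> packDequantScaleBias(const float* alpha, int size, bool asymmetric) {
        const int outputCount = asymmetric ? size / 2 : size;
        const int alignedOC   = (outputCount + 3) / 4 * 4;      // ALIGN_UP4
        std::vector<float> packed(alignedOC * 2, 0.0f);         // zero-filled like the ::memset above
        float* scale = packed.data();
        float* bias  = scale + alignedOC;
        for (int o = 0; o < outputCount; ++o) {
            if (asymmetric) {                                   // assumption: alpha holds (min, alpha) pairs
                bias[o]  = alpha[2 * o + 0];
                scale[o] = alpha[2 * o + 1];
            } else {
                scale[o] = alpha[o];                            // bias stays 0
            }
        }
        return packed;                                          // contents would be memcpy'd into the Metal buffer
    }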

View File

@ -21,7 +21,7 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected:
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight=false, bool int4Weight=false) override;
private:
id<MTLComputePipelineState> mPipeline;

View File

@ -82,12 +82,10 @@ ErrorCode MetalConvolutionDepthwise::onResize(const std::vector<Tensor *> &input
return NO_ERROR;
}
void MetalConvolutionDepthwise::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolutionDepthwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset: TensorUtils::getDescribe(input)->extra.offset
atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset: TensorUtils::getDescribe(output)->extra.offset
atIndex:1];
MetalBackend::setTensor(inputs[0], encoder, 0);
MetalBackend::setTensor(outputs[0], encoder, 1);
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
MetalBackend::setTensor(mWeight.get(), encoder, 3);
MetalBackend::setTensor(mBias.get(), encoder, 4);

View File

@ -23,7 +23,7 @@ public:
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
protected:
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual std::shared_ptr<MNN::Tensor> weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight=false, bool int4Weight=false) override;
private:

View File

@ -141,32 +141,40 @@ ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs
return NO_ERROR;
}
void MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
void MetalConvolutionWinograd::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0];
auto output = outputs[0];
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:1];
auto pipeline = [context pipelineWithName:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" fp16:backend->useFp16InsteadFp32()];
[encoder setComputePipelineState:pipeline];
MetalBackend::setTensor(input, encoder, 0);
MetalBackend::setTensor(mTempSrc.get(), encoder, 1);
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
auto gl = [context computeBestGroupAndLocal:pipeline threads:mInputTransformThreads];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
{ // gemm
auto bandwidth = [context load:@"matmul4x4" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:1];
auto pipeline = [context pipelineWithName:@"matmul4x4" fp16:backend->useFp16InsteadFp32()];
[encoder setComputePipelineState:pipeline];
MetalBackend::setTensor(mTempSrc.get(), encoder, 0);
MetalBackend::setTensor(mTempDst.get(), encoder, 1);
MetalBackend::setTensor(mWeight.get(), encoder, 2);
[encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
auto gl = [context computeBestGroupAndLocal:pipeline threads:mMatMulThreads];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:0];
auto pipeline = [context pipelineWithName:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" fp16:backend->useFp16InsteadFp32()];
[encoder setComputePipelineState:pipeline];
MetalBackend::setTensor(mTempDst.get(), encoder, 0);
MetalBackend::setTensor(mBias.get(), encoder, 1);
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
MetalBackend::setTensor(output, encoder, 2);
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
auto gl = [context computeBestGroupAndLocal:pipeline threads:mOutputTransformThreads];
[encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];
}
}
std::shared_ptr<MNN::Tensor> MetalConvolutionWinograd::weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight, bool int4Weight) {
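Note: throughout this file (and in MetalMatMul / MetalReLU6 below) the per-encode [context load:... encoder:...] + dispatchEncoder:threads:bandwidth: pattern is replaced by a cached id<MTLComputePipelineState> plus a precomputed threadgroup split, dispatched with dispatchThreadgroups:threadsPerThreadgroup:. The sketch below is only an illustrative approximation of deriving such a split from a total thread count; it is not MNN's computeBestGroupAndLocal:

    #include <algorithm>
    #include <array>

    struct Dispatch {
        std::array<int, 3> groups;   // threadgroups in the grid
        std::array<int, 3> local;    // threads per threadgroup
    };

    // Cover an (x, y, z) thread grid with ceil-divided threadgroups of at most maxLocal threads.
    Dispatch splitThreads(int x, int y, int z, int maxLocal = 256) {
        Dispatch d;
        d.local  = { std::min(x, 16), std::min(y, std::max(1, maxLocal / 16)), 1 };
        d.groups = { (x + d.local[0] - 1) / d.local[0],
                     (y + d.local[1] - 1) / d.local[1],
                     (z + d.local[2] - 1) / d.local[2] };
        return d;
    }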

View File

@ -1,3 +1,11 @@
//
// MetalExecution.hpp
// MNN
//
// Created by MNN on 2023/11/09.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MetalExecution_hpp
#define MetalExecution_hpp

View File

@ -1,3 +1,11 @@
//
// MetalExecution.mm
// MNN
//
// Created by MNN on 2023/11/09.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "MetalExecution.hpp"
#import "backend/metal/MetalBackend.hpp"

View File

@ -1,3 +1,11 @@
//
// MetalLoop.mm
// MNN
//
// Created by MNN on 2023/12/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "core/Macro.h"
#import "MetalCast.hpp"
#import "MetalBinary.hpp"

View File

@ -18,7 +18,7 @@ namespace MNN {
class MetalMatMul : public MetalExecution {
public:
MetalMatMul(Backend *backend, const MatMul *matmul);
MetalMatMul(Backend *backend, const MatMul *matmul, bool withBias);
virtual ~MetalMatMul();
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -27,6 +27,8 @@ private:
id<MTLBuffer> mConstBuffer = nil;
bool mTransposeA = false;
bool mTransposeB = false;
id<MTLComputePipelineState> mPipeline;
std::pair<MTLSize, MTLSize> mThreads;
};
} // namespace MNN

View File

@ -18,11 +18,17 @@ struct matP {
int size[4];
int stride[4];
};
MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul) : MetalExecution(backend) {
MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul, bool withBias) : MetalExecution(backend) {
mTransposeA = matmul->transposeA();
mTransposeB = matmul->transposeB();
auto mkbn = static_cast<MetalBackend *>(backend);
mConstBuffer = mkbn->getConstBuffer(sizeof(matP));
auto context = (__bridge MNNMetalContext *)mkbn->context();
if (withBias) {
mPipeline = [context pipelineWithName:@"matmul_bias" fp16:mkbn->useFp16InsteadFp32()];
} else {
mPipeline = [context pipelineWithName:@"matmul" fp16:mkbn->useFp16InsteadFp32()];
}
}
MetalMatMul::~MetalMatMul() {
auto mkbn = static_cast<MetalBackend *>(backend());
@ -59,6 +65,9 @@ ErrorCode MetalMatMul::onResize(const std::vector<Tensor *> &inputs, const std::
}
::memcpy(mConstBuffer.contents, &buffer, sizeof(matP));
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
mThreads = [context computeBestGroupAndLocal:mPipeline threads: MTLSizeMake(h, e, 1)];
return NO_ERROR;
}
@ -71,25 +80,20 @@ void MetalMatMul::onEncode(const std::vector<Tensor *> &inputs, const std::vecto
auto h = C->length(1);
if (inputs.size() > 2) {
auto bandwidth = [context load:@"matmul_bias" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[2]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[2])->extra.offset atIndex:2];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:3];
[encoder setComputePipelineState:mPipeline];
MetalBackend::setTensor(input0, encoder, 0);
MetalBackend::setTensor(input1, encoder, 1);
MetalBackend::setTensor(inputs[2], encoder, 2);
MetalBackend::setTensor(output, encoder, 3);
[encoder setBuffer:mConstBuffer offset:0 atIndex:4];
[context dispatchEncoder:encoder
threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
bandwidth:bandwidth];
} else {
auto bandwidth = [context load:@"matmul" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setComputePipelineState:mPipeline];
MetalBackend::setTensor(input0, encoder, 0);
MetalBackend::setTensor(input1, encoder, 1);
MetalBackend::setTensor(output, encoder, 2);
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder
threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
bandwidth:bandwidth];
}
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
class MetalMatMulCreator : public MetalBackend::Creator {
@ -100,7 +104,7 @@ public:
MNN_PRINT("metal not support matmul inpt size less than 2\n");
return nullptr;
}
return new MetalMatMul(backend, op->main_as_MatMul());
return new MetalMatMul(backend, op->main_as_MatMul(), inputs.size() > 2);
}
};
REGISTER_METAL_OP_CREATOR(MetalMatMulCreator, OpType_MatMul);
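Note: the pipeline (matmul vs matmul_bias) is now picked once in the constructor from whether a third (bias) input exists, and the threadgroup split over the (h, e) output grid is precomputed in onResize instead of being derived at encode time. For reference, a plain C++ sketch of what the two kernels are assumed to compute, one GPU thread per output element and the bias (if present) taken as one value per output column:

    #include <vector>

    // C[e x h] = op(A) * op(B) (+ bias), where op applies the optional transposes.
    std::vector<float> matmulRef(const std::vector<float>& A, const std::vector<float>& B,
                                 const std::vector<float>* bias,
                                 int e, int l, int h, bool transposeA, bool transposeB) {
        std::vector<float> C(e * h, 0.0f);
        for (int y = 0; y < e; ++y) {          // gid.y in the Metal grid
            for (int x = 0; x < h; ++x) {      // gid.x in the Metal grid
                float acc = bias ? (*bias)[x] : 0.0f;
                for (int k = 0; k < l; ++k) {
                    float a = transposeA ? A[k * e + y] : A[y * l + k];
                    float b = transposeB ? B[x * l + k] : B[k * h + x];
                    acc += a * b;
                }
                C[y * h + x] = acc;
            }
        }
        return C;
    }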

View File

@ -12,13 +12,15 @@
extern void ___MetalEltwiseCreator__OpType_Eltwise__();
extern void ___MetalConvolutionCreator__OpType_Convolution__();
extern void ___MetalLayerNormCreator__OpType_LayerNorm__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___AttentionBufCreator__OpType_Attention__();
#endif
extern void ___MetalMatMulCreator__OpType_MatMul__();
extern void ___MetalBinaryCreator__OpType_BinaryOp__();
extern void ___MetalConvolutionDepthwiseCreator__OpType_ConvolutionDepthwise__();
extern void ___MetalDeconvolutionCreator__OpType_Deconvolution__();
extern void ___MetalDeconvolutionCreator__OpType_DeconvolutionDepthwise__();
extern void ___MetalLoopCreator__OpType_While__();
extern void ___MetalReLUCreator__OpType_ReLU__();
extern void ___MetalPoolingCreator__OpType_Pooling__();
extern void ___MetalScaleCreator__OpType_Scale__();
extern void ___MetalInterpCreator__OpType_Interp__();
@ -31,6 +33,7 @@
extern void ___MetalFuseCreator__OpType_Extra__();
extern void ___MetalPReLUCreator__OpType_PReLU__();
extern void ___MetalReLU6Creator__OpType_ReLU6__();
extern void ___MetalReLU6Creator__OpType_ReLU__();
void registerMetalOps() {
___MetalArgMaxCreator__OpType_ArgMax__();
___MetalArgMaxCreator__OpType_ArgMin__();
@ -48,7 +51,6 @@ void registerMetalOps() {
___MetalDeconvolutionCreator__OpType_Deconvolution__();
___MetalDeconvolutionCreator__OpType_DeconvolutionDepthwise__();
___MetalLoopCreator__OpType_While__();
___MetalReLUCreator__OpType_ReLU__();
___MetalPoolingCreator__OpType_Pooling__();
___MetalScaleCreator__OpType_Scale__();
___MetalInterpCreator__OpType_Interp__();
@ -61,6 +63,10 @@ void registerMetalOps() {
___MetalFuseCreator__OpType_Extra__();
___MetalPReLUCreator__OpType_PReLU__();
___MetalReLU6Creator__OpType_ReLU6__();
___MetalReLU6Creator__OpType_ReLU__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___AttentionBufCreator__OpType_Attention__();
#endif
}
#endif
}

View File

@ -24,6 +24,7 @@ public:
private:
float mSpatialScale;
id<MTLBuffer> mShape;
id<MTLComputePipelineState> mPipeline;
};
} // namespace MNN

View File

@ -16,7 +16,10 @@ namespace MNN {
MetalROIPooling::MetalROIPooling(Backend *backend, float spatialScale)
: MetalExecution(backend), mSpatialScale(spatialScale) {
// nothing to do
auto mtbn = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context();
mShape = [context newDeviceBuffer:8 * sizeof(int) + sizeof(float) access:CPUWriteOnly];
mPipeline = [context pipelineWithName:@"ROI_pooling" fp16:mtbn->useFp16InsteadFp32()];
}
ErrorCode MetalROIPooling::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
@ -25,16 +28,15 @@ ErrorCode MetalROIPooling::onResize(const std::vector<Tensor *> &inputs, const s
int iw = input->width(), ih = input->height();
int ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4), ob = output->batch();
auto shape = [context newDeviceBuffer:7 * sizeof(int) + sizeof(float) access:CPUWriteOnly];
((int *)shape.contents)[0] = iw;
((int *)shape.contents)[1] = ih;
((int *)shape.contents)[2] = iw * ih;
((int *)shape.contents)[3] = ow;
((int *)shape.contents)[4] = oh;
((int *)shape.contents)[5] = ow * oh;
((int *)shape.contents)[6] = oz;
((float *)shape.contents)[7] = mSpatialScale;
mShape = shape;
((int *)mShape.contents)[0] = iw;
((int *)mShape.contents)[1] = ih;
((int *)mShape.contents)[2] = iw * ih;
((int *)mShape.contents)[3] = input->batch();
((int *)mShape.contents)[4] = ow;
((int *)mShape.contents)[5] = oh;
((int *)mShape.contents)[6] = ow * oh;
((int *)mShape.contents)[7] = ob;
((float *)mShape.contents)[8] = mSpatialScale;
return NO_ERROR;
}
@ -44,18 +46,22 @@ void MetalROIPooling::onEncode(const std::vector<Tensor *> &inputs, const std::v
auto input = inputs[0], roi = inputs[1], output = outputs[0];
int iw = input->width(), ih = input->height();
int ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4), ob = output->batch();
auto bandwidth = [context load:@"ROI_pooling" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)roi->deviceId())->getBuffer() offset:TensorUtils::getDescribe(roi)->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setComputePipelineState:mPipeline];
MetalBackend::setTensor(input, encoder, 0);
MetalBackend::setTensor(roi, encoder, 1);
MetalBackend::setTensor(output, encoder, 2);
[encoder setBuffer:mShape offset:0 atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(ow, 16), UP_DIV(oh, 16), ob * oz) threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];
}
class MetalROIPoolingCreator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
auto roi = inputs[1];
if (TensorUtils::getDescribe(roi)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
// Don't support old roipooling
return nullptr;
}
return new MetalROIPooling(backend, op->main_as_RoiParameters()->spatialScale());
}
};
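Note: the shape constant buffer is now created once in the constructor as 8 ints plus a float (previously 7 ints plus a float allocated on every resize); the extra slot carries the input batch. The layout implied by the indices written in onResize, with illustrative field names:

    // Mirrors the 8 * sizeof(int) + sizeof(float) buffer created in the constructor.
    struct RoiPoolingShape {
        int   inputWidth;    // [0] iw
        int   inputHeight;   // [1] ih
        int   inputSize;     // [2] iw * ih
        int   inputBatch;    // [3] input->batch()
        int   outputWidth;   // [4] ow
        int   outputHeight;  // [5] oh
        int   outputSize;    // [6] ow * oh
        int   outputBatch;   // [7] ob
        float spatialScale;  // [8] mSpatialScale
    };
    static_assert(sizeof(RoiPoolingShape) == 8 * sizeof(int) + sizeof(float),
                  "struct matches the Metal constant buffer size");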

View File

@ -210,7 +210,7 @@ ErrorCode MetalRaster::onResize(const std::vector<Tensor *> &____inputs, const s
fast = false;
break;
}
if (!OpCommonUtils::canBlitFast(slice, output)) {
if (!OpCommonUtils::canBlitFast(slice, output, 4, true)) {
fast = false;
break;
}
@ -239,7 +239,7 @@ ErrorCode MetalRaster::onResize(const std::vector<Tensor *> &____inputs, const s
for (int v=0; v<iter.second.size(); ++v) {
auto& oldr = des->regions[iter.second[v]];
Tensor::InsideDescribe::Region slice;
OpCommonUtils::turnToPackRegion(oldr, slice, output, 4);
OpCommonUtils::turnToPackRegion(oldr, slice, output, 4, true);
slice.dst.offset /= 4;
slice.src.offset /= 4;
writeSamplerInfo(infoP[v], slice);

View File

@ -1,28 +0,0 @@
//
// MetalReLU.hpp
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MetalReLU_hpp
#define MetalReLU_hpp
#import "MetalExecution.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalReLU : public MetalExecution {
public:
MetalReLU(Backend *backend, float slope);
virtual ~MetalReLU() = default;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
id<MTLBuffer> mSlope;
};
} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalReLU_hpp */

View File

@ -1,51 +0,0 @@
//
// MetalReLU.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalReLU.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
MetalReLU::MetalReLU(Backend *backend, float slope) : MetalExecution(backend) {
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
mSlope = [context newDeviceBuffer:sizeof(float) bytes:&slope access:CPUWriteOnly];
}
void MetalReLU::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0], output = outputs[0];
NSUInteger size = output->elementSize();
auto simd = size % 4 == 0;
if (simd) {
size /= 4;
}
MNN_ASSERT(mSlope.length == sizeof(float));
auto bandwidth = [context load:simd ? @"relu_x4" : @"relu_x1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mSlope offset:0 atIndex:2];
[context dispatchEncoder:encoder threads:{ size, 1, 1 } bandwidth:bandwidth];
MNN_PRINT_ENCODER(context, encoder);
}
class MetalReLUCreator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
return new MetalReLU(backend, op->main_as_Relu()->slope());
}
};
REGISTER_METAL_OP_CREATOR(MetalReLUCreator, OpType_ReLU);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */

View File

@ -16,11 +16,12 @@ namespace MNN {
class MetalReLU6 : public MetalExecution {
public:
MetalReLU6(Backend *backend, float minValue, float maxValue);
MetalReLU6(Backend *backend, float minValue, float maxValue, bool isRelu);
virtual ~MetalReLU6() = default;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
id<MTLBuffer> mConst;
id<MTLComputePipelineState> mPipeline;
};
} // namespace MNN

View File

@ -15,33 +15,39 @@
#if MNN_METAL_ENABLED
namespace MNN {
MetalReLU6::MetalReLU6(Backend *backend, float minV, float maxV) : MetalExecution(backend) {
MetalReLU6::MetalReLU6(Backend *backend, float minV, float maxV, bool isRelu) : MetalExecution(backend) {
// For Relu use minV and slope
auto metal = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)metal->context();
mConst = [context newDeviceBuffer:4 * sizeof(float) access:CPUWriteOnly];
mConst = [context newDeviceBuffer:4 * sizeof(float) access:CPUWriteOnly];
((float*)mConst.contents)[0] = minV;
((float*)mConst.contents)[1] = maxV;
if (isRelu) {
mPipeline = [context pipelineWithName:@"relu" fp16:metal->useFp16InsteadFp32()];
} else {
mPipeline = [context pipelineWithName:@"relu6" fp16:metal->useFp16InsteadFp32()];
}
}
void MetalReLU6::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0], output = outputs[0];
NSUInteger size = output->elementSize();
auto simd = size % 4 == 0;
if (simd) {
size /= 4;
}
int size = output->elementSize();
size = UP_DIV(size, 4);
((int*)mConst.contents)[2] = size;
((int*)mConst.contents)[3] = size;
auto bandwidth = [context load:simd ? @"relu6_x4" : @"relu6_x1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setComputePipelineState:mPipeline];
MetalBackend::setTensor(input, encoder, 0);
MetalBackend::setTensor(output, encoder, 1);
[encoder setBuffer:mConst offset:0 atIndex:2];
[context dispatchEncoder:encoder threads:{ size, 1, 1 } bandwidth:bandwidth];
[encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
}
class MetalReLU6Creator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
if (op->type() == OpType_ReLU) {
return new MetalReLU6(backend, op->main_as_Relu()->slope(), 0.0f, true);
}
float minV = 0.0f;
float maxV = 6.0f;
if (nullptr != op->main()) {
@ -49,9 +55,10 @@ public:
minV = p->minValue();
maxV = p->maxValue();
}
return new MetalReLU6(backend, minV, maxV);
return new MetalReLU6(backend, minV, maxV, false);
}
};
REGISTER_METAL_OP_CREATOR(MetalReLU6Creator, OpType_ReLU6);
REGISTER_METAL_OP_CREATOR(MetalReLU6Creator, OpType_ReLU);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */
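Note: the standalone MetalReLU execution and shader are deleted; OpType_ReLU is now registered on the ReLU6 creator, which forwards the slope through the first constant ("For Relu use minV and slope") and dispatches 256-thread groups over UP_DIV(elementSize, 4) float4 values. The assumed element-wise semantics of the merged kernels, as a scalar C++ sketch:

    #include <algorithm>

    // "relu" is assumed to treat the first constant as a (leaky) negative slope,
    // "relu6" clamps the value to [minV, maxV].
    inline float reluRef(float x, float slope)             { return x > 0.0f ? x : slope * x; }
    inline float relu6Ref(float x, float minV, float maxV) { return std::min(std::max(x, minV), maxV); }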

View File

@ -3,7 +3,6 @@
namespace MNN {
void ShaderMap::init() {
mMaps.insert(std::make_pair("shader_MetalReLU6_metal", shader_MetalReLU6_metal));
mMaps.insert(std::make_pair("shader_MetalReLU_metal", shader_MetalReLU_metal));
mMaps.insert(std::make_pair("shader_MetalConvolutionDepthwise_metal", shader_MetalConvolutionDepthwise_metal));
mMaps.insert(std::make_pair("shader_MetalConvolutionActivation_metal", shader_MetalConvolutionActivation_metal));
mMaps.insert(std::make_pair("shader_MetalConvolution_metal", shader_MetalConvolution_metal));

View File

@ -1,38 +1,10 @@
struct tensor_shape {
int size;
int channel;
int slice;
int batch;
int batch_slices;
};
kernel void version_func_002(const device uchar *in [[buffer(0)]],
device uchar *out [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
// do nothing, just for verifying match between mnn and metallib
}
kernel void copy_byte(const device uchar *in [[buffer(0)]],
device uchar *out [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = in[int(gid)];
}
kernel void copy_float(const device ftype *in [[buffer(0)]],
device ftype *out [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = in[int(gid)];
}
kernel void upcast_float(const device ftype *in [[buffer(0)]],
device float *out [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = in[int(gid)];
}
kernel void downcast_float(const device float *in [[buffer(0)]],
device ftype *out [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = in[int(gid)];
}
struct Limit {
uint4 size;
};
@ -55,8 +27,8 @@ kernel void downcast_float4(const device float4 *in [[buffer(0)]],
template <typename IType, typename OType>
static inline void template_NHWC_to_NC4HW4(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int z = gid.y % s.slice;
int b = gid.y % s.batch;
int z = gid.y / s.batch;
int c = z * 4;
auto off_in = in + b * s.size * s.channel + int(gid.x) * s.channel + c;
@ -93,8 +65,8 @@ kernel void cvt_f_NHWC_to_NC4HW4(const device ftype *in [[buffer(0)]],
template <typename IType, typename OType>
static inline void template_NC4HW4_to_NHWC(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int z = gid.y % s.slice;
int b = gid.y % s.batch;
int z = gid.y / s.batch;
int c = z * 4;
auto off_in = in + int(gid.y) * s.size + int(gid.x);
auto off_out = out + b * s.size * s.channel + int(gid.x) * s.channel + c;
@ -132,8 +104,8 @@ kernel void cvt_f_NC4HW4_to_NHWC(const device ftype4 *in [[buffer(0)]],
template <typename IType, typename OType>
static inline void template_NCHW_to_NC4HW4(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int z = gid.y % s.slice;
int b = gid.y % s.batch;
int z = gid.y / s.batch;
int c = z * 4;
auto off_in = in + (b * s.channel + c) * s.size + int(gid.x);
@ -170,8 +142,8 @@ kernel void cvt_f_NCHW_to_NC4HW4(const device ftype *in [[buffer(0)]],
template <typename IType, typename OType>
static inline void template_NC4HW4_to_NCHW(const device IType *in, device OType *out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int z = gid.y % s.slice;
int b = gid.y % s.batch;
int z = gid.y / s.batch;
int c = z * 4;
auto off_in = in + int(gid.y) * s.size + int(gid.x);
@ -210,8 +182,8 @@ kernel void cvt_f_NC4HW4_to_NCHW(const device ftype4 *in [[buffer(0)]],
template<typename IType, typename OType>
static inline void template_NHWC_to_NCHW(const device IType* in,
device OType* out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int c4 = gid.y % s.slice;
int b = gid.y % s.batch;
int c4 = gid.y / s.batch;
auto in_off = (b * s.size + gid.x) * s.channel + c4 * 4;
auto out_off = (b * s.channel + c4 * 4) * s.size + gid.x;
@ -238,8 +210,8 @@ kernel void upcast_f_NHWC_to_NCHW(const device ftype *in [[buffer(0)]],
template<typename IType, typename OType>
static inline void template_NCHW_to_NHWC(const device IType* in,
device OType* out, constant tensor_shape &s, uint2 gid) {
int b = gid.y / s.slice;
int c4 = gid.y % s.slice;
int b = gid.y % s.batch;
int c4 = gid.y / s.batch;
auto in_off = (b * s.channel + c4 * 4) * s.size + gid.x;
auto out_off = (b * s.size + gid.x) * s.channel + c4 * 4;
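Note: every layout-conversion template swaps the decomposition of gid.y from (b = gid.y / slice, z = gid.y % slice) to (b = gid.y % batch, z = gid.y / batch), i.e. gid.y = z * batch + b, while the packed side is still addressed as gid.y * size + gid.x. The NC4HW4 device layout this implies keeps the channel slice outermost with the batch packed inside it; a small sketch of that addressing (names illustrative):

    #include <cstdio>

    // Offset of float4 element (b, z, hw) in the batch-inside-slice layout,
    // where size = H * W and z indexes groups of 4 channels.
    inline int nc4hw4Offset(int b, int z, int hw, int batch, int size) {
        return (z * batch + b) * size + hw;
    }

    int main() {
        const int batch = 2, slice = 3, size = 4;   // toy shape
        for (int z = 0; z < slice; ++z)
            for (int b = 0; b < batch; ++b)
                std::printf("b=%d z=%d -> base offset %d\n", b, z, nc4hw4Offset(b, z, 0, batch, size));
        return 0;
    }

The convolution and conv1x1 kernels below follow the same rule: stepping one channel slice on the input becomes cst.input_size * cst.batch, and one output channel slice is offset by cst.output_size * cst.batch.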

View File

@ -102,9 +102,9 @@ kernel void conv(const device ftype4 *in [[buffer(0)]],
offset_x += sx * cst.dilation_x;
offset_y += sy * cst.dilation_y;
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + (int)idx_c * cst.output_size + (int)gid.y * cst.output_width + (int)gid.x;
auto z_out = out + idx_b * cst.output_size + (int)idx_c * cst.batch * cst.output_size + (int)gid.y * cst.output_width + (int)gid.x;
int dilation_h = cst.input_width * cst.dilation_y;
FLOAT4 result = FLOAT4(biasTerms[idx_c]);
@ -112,7 +112,7 @@ kernel void conv(const device ftype4 *in [[buffer(0)]],
for (auto y = 0; y < kh; y++) {
for (auto x = 0; x < kw; x++) {
auto wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x + x];
auto in4 = z_in[z * cst.input_size + y * dilation_h + x * cst.dilation_x];
auto in4 = z_in[z * cst.input_size * cst.batch + y * dilation_h + x * cst.dilation_x];
result += FLOAT4(in4 * wt4);
}
}
@ -141,15 +141,15 @@ kernel void convk3s1d1p1_w2z4(const device ftype4 *in [[buffer(0)]],
int offset_x = (int)gid.x * 2 - cst.pad_x;
int offset_y = (int)gid.y - cst.pad_y;
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_flt = wt + uz[0] * cst.input_slice * cst.kernel_size;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
int ws = cst.input_slice * cst.kernel_size;
FLOAT4 result0 = 0, result1 = 0, result2 = 0, result3 = 0;
FLOAT4 result4 = 0, result5 = 0, result6 = 0, result7 = 0;
for (auto z = 0; z < cst.input_slice; z++, z_flt += cst.kernel_size, z_in += cst.input_size) {
for (auto z = 0; z < cst.input_slice; z++, z_flt += cst.kernel_size, z_in += (cst.input_size * cst.batch)) {
auto in00 = (offset_x<0 || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+0);
auto in01 = (offset_x+1>=cst.input_width || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+1);
auto in02 = (offset_x+2>=cst.input_width || offset_y<0) ? (ftype4)0.f : *(z_in+0*cst.input_width+2);
@ -228,19 +228,19 @@ kernel void conv_s1d1p0_w2(const device ftype4 *in [[buffer(0)]],
bool valid = (idx_w + 1 < cst.output_width);
auto z_in = in + idx_b * cst.input_slice * cst.input_size + idx_h * cst.input_width + idx_w;
auto z_in = in + idx_b * cst.input_size + idx_h * cst.input_width + idx_w;
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto z_out = out + idx_b * cst.output_size + idx_c * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
FLOAT4 result0 = FLOAT4(biasTerms[idx_c]);
FLOAT4 result1 = result0;
for (auto z = 0; z < cst.input_slice; z++) {
for (auto y = 0; y < cst.kernel_y; y++) {
auto wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x];
auto in4_0 = z_in[z * cst.input_size + y * cst.input_width];
auto in4_0 = z_in[z * cst.batch * cst.input_size + y * cst.input_width];
result0 += FLOAT4(in4_0 * wt4);
for (auto x = 1; x < cst.kernel_x; x++) {
in4_0 = z_in[z * cst.input_size + y * cst.input_width + x];
in4_0 = z_in[z * cst.batch * cst.input_size + y * cst.input_width + x];
result1 += FLOAT4(in4_0 * wt4);
wt4 = z_wt[z * cst.kernel_size + y * cst.kernel_x + x];
result0 += FLOAT4(in4_0 * wt4);
@ -271,9 +271,9 @@ kernel void conv_s1d1p0_w4(const device ftype4 *in [[buffer(0)]],
int3 uz = idx_w + int3(1, 2, 3);
bool3 valids = uz.xyz < cst.output_width;
auto z_in = in + idx_b * cst.input_slice * cst.input_size + idx_h * cst.input_width + idx_w;
auto z_in = in + idx_b * cst.input_size + idx_h * cst.input_width + idx_w;
auto z_wt = wt + idx_c * cst.input_slice * cst.kernel_size;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto z_out = out + idx_b * cst.output_size + idx_c * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
FLOAT4 result0 = FLOAT4(biasTerms[idx_c]);
FLOAT4 result1 = result0;
@ -286,7 +286,7 @@ kernel void conv_s1d1p0_w4(const device ftype4 *in [[buffer(0)]],
auto wt4_1 = wt_base[1];
auto wt4_2 = wt_base[2];
auto z_in_base = z_in + z * cst.input_size + y * cst.input_width;
auto z_in_base = z_in + z * cst.batch * cst.input_size + y * cst.input_width;
auto in4_0 = z_in_base[0];
result0 += FLOAT4(in4_0 * wt4_0);
@ -346,14 +346,14 @@ kernel void conv_z4(const device ftype4 *in [[buffer(0)]],
offset_x += sx * cst.dilation_x;
offset_y += sy * cst.dilation_y;
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_wt = wt + uz[0] * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
int ws = cst.input_slice * cst.kernel_size;
int dilation_h = cst.input_width * cst.dilation_y;
FLOAT4 result0 = 0, result1 = 0, result2 = 0, result3 = 0;
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size) {
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size * cst.batch) {
for (auto y = 0; y < kh; y++) {
for (auto x = 0; x < kw; x++) {
auto x_wt = z_wt + y * cst.kernel_x + x;
@ -400,14 +400,14 @@ kernel void conv_z2(const device ftype4 *in [[buffer(0)]],
offset_x += sx * cst.dilation_x;
offset_y += sy * cst.dilation_y;
auto z_in = in + idx_b * cst.input_slice * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_in = in + idx_b * cst.input_size + offset_y * cst.input_width + offset_x;
auto z_wt = wt + uz[0] * cst.input_slice * cst.kernel_size + sy * cst.kernel_x + sx;
auto z_out = out + idx_b * cst.output_slice * cst.output_size + uz[0] * cst.output_size + idx_h * cst.output_width + idx_w;
auto z_out = out + idx_b * cst.output_size + uz[0] * cst.batch * cst.output_size + idx_h * cst.output_width + idx_w;
int ws = cst.input_slice * cst.kernel_size;
int dilation_h = cst.input_width * cst.dilation_y;
FLOAT4 result0 = 0, result1 = 0;
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size) {
for (auto z = 0; z < cst.input_slice; z++, z_wt += cst.kernel_size, z_in += cst.input_size * cst.batch) {
for (auto y = 0; y < kh; y++) {
for (auto x = 0; x < kw; x++) {
auto x_wt = z_wt + y * cst.kernel_x + x;
@ -418,5 +418,5 @@ kernel void conv_z2(const device ftype4 *in [[buffer(0)]],
}
}
/* true */ *z_out = activate(ftype4(result0 + FLOAT4(biasTerms[uz[0]])), cst.activation);
if (valids) { z_out += cst.output_size; *z_out = activate(ftype4(result1 + FLOAT4(biasTerms[uz[1]])), cst.activation); }
if (valids) { z_out += cst.output_size * cst.batch; *z_out = activate(ftype4(result1 + FLOAT4(biasTerms[uz[1]])), cst.activation); }
}

View File

@ -13,36 +13,6 @@ struct conv1x1_constants {
conv_activation_type activation;
};
kernel void conv1x1_w1h1(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
int idx_w = gid.x;
int idx_h = gid.y;
int idx_c = gid.z % cst.output_slice;
int idx_b = gid.z / cst.output_slice;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto biasValue = FLOAT4(biasTerms[idx_c]);
FLOAT4 result0 = biasValue;
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = xy_in0[0];
auto w = xy_wt[z];
result0 += FLOAT4(in40 * w);
xy_in0 += cst.input_size;
}
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
}
kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
@ -54,8 +24,8 @@ kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
int rx = gid.x * CONV_UNROLL;
int uz = gid.y;
auto xy_wt = wt + uz * cst.input_slice;
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
auto biasValue = FLOAT4(biasTerms[uz]);
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
@ -71,7 +41,7 @@ kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]],
result1 += FLOAT4(in41 * w);
result2 += FLOAT4(in42 * w);
result3 += FLOAT4(in43 * w);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
@ -86,20 +56,19 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]],
const device MNN::char4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
const device float4 *dequantScale [[buffer(5)]],
const device float4 *dequantBias [[buffer(6)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
int rx = gid.x * CONV_UNROLL;
int uz = gid.y;
auto xy_wt = wt + uz * cst.input_slice;
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
auto biasValue = FLOAT4(biasTerms[uz]);
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
auto scale = FLOAT4(dequantScale[uz]);
auto dequant_bias = FLOAT4(dequantBias[uz]);
auto dequant_bias = FLOAT4(dequantScale[uz + cst.output_slice]);
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = (FLOAT4)*xy_in0;
@ -112,20 +81,14 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]],
FLOAT4x4 w_fp32 = FLOAT4x4(FLOAT4(w[0]), FLOAT4(w[1]), FLOAT4(w[2]), FLOAT4(w[3]));
FLOAT4x4 w_dequant;
for (int i = 0; i < 4; ++i) {
FLOAT4 w4 = w_fp32[i];
FLOAT4 res;
for (int j = 0; j < 4; ++j) {
float wf = w4[j] * scale[i] + dequant_bias[i];
res[j] = wf;
}
w_dequant[i] = res;
w_dequant[i] = w_fp32[i] * scale[i] + dequant_bias[i];
}
result0 += FLOAT4(in40 * w_dequant);
result1 += FLOAT4(in41 * w_dequant);
result2 += FLOAT4(in42 * w_dequant);
result3 += FLOAT4(in43 * w_dequant);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
/* true */
@ -141,20 +104,19 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
const device MNN::uchar4x2 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
const device float4 *dequantScale [[buffer(5)]],
const device float4 *dequantBias [[buffer(6)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;
int rx = gid.x * CONV_UNROLL;
int uz = gid.y;
auto xy_wt = wt + uz * cst.input_slice;
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.output_size * cst.batch + rx;
auto biasValue = FLOAT4(biasTerms[uz]);
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
int computeSize = min(cst.output_size - rx, CONV_UNROLL);
auto scale = FLOAT4(dequantScale[uz]);
auto dequant_bias = FLOAT4(dequantBias[uz]);
auto dequant_bias = FLOAT4(dequantScale[uz + cst.output_slice]);
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = (FLOAT4)*xy_in0;
@ -169,11 +131,7 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
for (int i = 0; i < 4; ++i) {
// ftype4 w4 = ftype4(w_fp32[i]);
FLOAT4 w4 = FLOAT4((float)(w_int4[i][0] >> 4) - 8, (float)(w_int4[i][0] & 15) - 8, (float)(w_int4[i][1] >> 4) - 8, (float)(w_int4[i][1] & 15) - 8);
FLOAT4 res;
for (int j = 0; j < 4; ++j) {
float wf = w4[j] * scale[i] + dequant_bias[i];
res[j] = wf;
}
FLOAT4 res = w4 * scale[i] + dequant_bias[i];
w_dequant[i] = res;
}
@ -181,7 +139,7 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]],
result1 += FLOAT4(in41 * w_dequant);
result2 += FLOAT4(in42 * w_dequant);
result3 += FLOAT4(in43 * w_dequant);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
/* true */
@ -213,9 +171,9 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
int rx = gid.x * CONV_UNROLL_L;
int uz = gid.y;
auto xy_wt = wt + uz * cst.input_slice;
auto xy_in0 = in + (int)gid.z * cst.input_slice * cst.input_size + rx + 0;
auto xy_in0 = in + (int)gid.z * cst.input_size + rx + 0;
auto xy_out = out + (int)gid.z * cst.output_slice * cst.output_size + uz * cst.output_size + rx;
auto xy_out = out + (int)gid.z * cst.output_size + uz * cst.batch * cst.output_size + rx;
auto biasValue = FLOAT4(biasTerms[uz]);
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
FLOAT4 result4 = biasValue, result5 = biasValue, result6 = biasValue, result7 = biasValue;
@ -241,7 +199,7 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
result5 += FLOAT4(in45 * w);
result6 += FLOAT4(in46 * w);
result7 += FLOAT4(in47 * w);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
@ -254,84 +212,23 @@ kernel void conv1x1_g1z8(const device ftype4 *in [[buffer(0)]],
if (computeSize > 7) {xy_out[7] = activate(ftype4(result7), cst.activation); }
}
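Note: the conv1x1_g1z4_w8 / conv1x1_g1z4_w4 kernels above now take a single dequant buffer (the bias half is read at dequantScale[uz + cst.output_slice]) and dequantize a whole row at once as w * scale[i] + bias[i]. For the int4 path each byte packs two unsigned nibbles recentered by 8; a scalar C++ sketch of that decode:

    #include <cstdint>

    // Unpack one 4-bit weight, recenter from [0, 15] to [-8, 7], then dequantize.
    inline float dequantInt4(std::uint8_t packed, bool highNibble, float scale, float bias) {
        const int q = highNibble ? (packed >> 4) : (packed & 0x0F);
        return (static_cast<float>(q) - 8.0f) * scale + bias;
    }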
kernel void conv1x1_w4h2(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
int idx_w = gid.x << 2;
int idx_h = gid.y << 1;
int idx_c = gid.z % cst.output_slice;
int idx_b = gid.z / cst.output_slice;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto biasValue = FLOAT4(biasTerms[idx_c]);
FLOAT4 result0 = biasValue, result1 = biasValue, result2 = biasValue, result3 = biasValue;
FLOAT4 result4 = biasValue, result5 = biasValue, result6 = biasValue, result7 = biasValue;
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = xy_in0[0];
auto in41 = xy_in0[1];
auto in42 = xy_in0[2];
auto in43 = xy_in0[3];
auto in44 = xy_in0[cst.output_width+0];
auto in45 = xy_in0[cst.output_width+1];
auto in46 = xy_in0[cst.output_width+2];
auto in47 = xy_in0[cst.output_width+3];
auto w = xy_wt[z];
result0 += FLOAT4(in40 * w);
result1 += FLOAT4(in41 * w);
result2 += FLOAT4(in42 * w);
result3 += FLOAT4(in43 * w);
result4 += FLOAT4(in44 * w);
result5 += FLOAT4(in45 * w);
result6 += FLOAT4(in46 * w);
result7 += FLOAT4(in47 * w);
xy_in0 += cst.input_size;
}
int widthSize = min(cst.output_width - idx_w, 4);
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
if (widthSize > 2) {xy_out[2] = activate(ftype4(result2), cst.activation); }
if (widthSize > 3) {xy_out[3] = activate(ftype4(result3), cst.activation); }
int heightSize = min(cst.output_height - idx_h, 2);
if(heightSize > 1) {
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
if (widthSize > 2) {xy_out[cst.output_width+2] = activate(ftype4(result6), cst.activation); }
if (widthSize > 3) {xy_out[cst.output_width+3] = activate(ftype4(result7), cst.activation); }
}
}
kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 4 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
if ((int)gid.x * 16 >= cst.output_width || (int)gid.y >= cst.batch * cst.output_slice) return;
int idx_w = gid.x << 2;
int idx_h = gid.y << 2;
int idx_c = gid.z % cst.output_slice;
int idx_b = gid.z / cst.output_slice;
int idx_w = gid.x << 4;
int idx_h = 0;
int idx_c = gid.y / cst.batch;
int idx_b = gid.y % cst.batch;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
auto biasValue = FLOAT4(biasTerms[idx_c]);
FLOAT4 result00 = biasValue, result01 = biasValue, result02 = biasValue, result03 = biasValue;
FLOAT4 result10 = biasValue, result11 = biasValue, result12 = biasValue, result13 = biasValue;
@ -343,19 +240,19 @@ kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
auto in01 = xy_in0[1];
auto in02 = xy_in0[2];
auto in03 = xy_in0[3];
auto in10 = xy_in0[cst.output_width+0];
auto in11 = xy_in0[cst.output_width+1];
auto in12 = xy_in0[cst.output_width+2];
auto in13 = xy_in0[cst.output_width+3];
auto in10 = xy_in0[4];
auto in11 = xy_in0[5];
auto in12 = xy_in0[6];
auto in13 = xy_in0[7];
auto in20 = xy_in0[cst.output_width+cst.output_width+0];
auto in21 = xy_in0[cst.output_width+cst.output_width+1];
auto in22 = xy_in0[cst.output_width+cst.output_width+2];
auto in23 = xy_in0[cst.output_width+cst.output_width+3];
auto in30 = xy_in0[cst.output_width+cst.output_width+cst.output_width+0];
auto in31 = xy_in0[cst.output_width+cst.output_width+cst.output_width+1];
auto in32 = xy_in0[cst.output_width+cst.output_width+cst.output_width+2];
auto in33 = xy_in0[cst.output_width+cst.output_width+cst.output_width+3];
auto in20 = xy_in0[8];
auto in21 = xy_in0[9];
auto in22 = xy_in0[10];
auto in23 = xy_in0[11];
auto in30 = xy_in0[12];
auto in31 = xy_in0[13];
auto in32 = xy_in0[14];
auto in33 = xy_in0[15];
auto w = xy_wt[z];
@ -378,34 +275,26 @@ kernel void conv1x1_w4h4(const device ftype4 *in [[buffer(0)]],
result32 += FLOAT4(in32 * w);
result33 += FLOAT4(in33 * w);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
int widthSize = min(cst.output_width - idx_w, 4);
int widthSize = min(cst.output_width - idx_w, 16);
/* true */ *xy_out = activate(ftype4(result00), cst.activation);
if (widthSize > 1) {xy_out[1] = activate(ftype4(result01), cst.activation); }
if (widthSize > 2) {xy_out[2] = activate(ftype4(result02), cst.activation); }
if (widthSize > 3) {xy_out[3] = activate(ftype4(result03), cst.activation); }
int heightSize = min(cst.output_height - idx_h, 4);
if(heightSize > 1) {
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result10), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result11), cst.activation); }
if (widthSize > 2) {xy_out[cst.output_width+2] = activate(ftype4(result12), cst.activation); }
if (widthSize > 3) {xy_out[cst.output_width+3] = activate(ftype4(result13), cst.activation); }
}
if(heightSize > 2) {
/* true */ {xy_out[cst.output_width+cst.output_width+0] = activate(ftype4(result20), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+cst.output_width+1] = activate(ftype4(result21), cst.activation); }
if (widthSize > 2) {xy_out[cst.output_width+cst.output_width+2] = activate(ftype4(result22), cst.activation); }
if (widthSize > 3) {xy_out[cst.output_width+cst.output_width+3] = activate(ftype4(result23), cst.activation); }
}
if(heightSize > 3) {
/* true */ {xy_out[cst.output_width+cst.output_width+cst.output_width+0] = activate(ftype4(result30), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+cst.output_width+cst.output_width+1] = activate(ftype4(result31), cst.activation); }
if (widthSize > 2) {xy_out[cst.output_width+cst.output_width+cst.output_width+2] = activate(ftype4(result32), cst.activation); }
if (widthSize > 3) {xy_out[cst.output_width+cst.output_width+cst.output_width+3] = activate(ftype4(result33), cst.activation); }
}
if (widthSize > 4) {xy_out[4] = activate(ftype4(result10), cst.activation); }
if (widthSize > 5) {xy_out[5] = activate(ftype4(result11), cst.activation); }
if (widthSize > 6) {xy_out[6] = activate(ftype4(result12), cst.activation); }
if (widthSize > 7) {xy_out[7] = activate(ftype4(result13), cst.activation); }
if (widthSize > 8) {xy_out[8] = activate(ftype4(result20), cst.activation); }
if (widthSize > 9) {xy_out[9] = activate(ftype4(result21), cst.activation); }
if (widthSize > 10) {xy_out[10] = activate(ftype4(result22), cst.activation); }
if (widthSize > 11) {xy_out[11] = activate(ftype4(result23), cst.activation); }
if (widthSize > 12) {xy_out[12] = activate(ftype4(result30), cst.activation); }
if (widthSize > 13) {xy_out[13] = activate(ftype4(result31), cst.activation); }
if (widthSize > 14) {xy_out[14] = activate(ftype4(result32), cst.activation); }
if (widthSize > 15) {xy_out[15] = activate(ftype4(result33), cst.activation); }
}
@ -415,19 +304,19 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z * 2 >= cst.batch * cst.output_slice) return;
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.batch * cst.output_slice) return;
int channel_pack = (cst.output_channel + 7) >> 3;
int idx_w = gid.x << 1;
int idx_h = gid.y;
int idx_c = (gid.z % channel_pack) << 1;
int idx_b = gid.z / channel_pack;
int idx_h = 0;
int idx_c = (gid.y % channel_pack) << 1;
int idx_b = gid.y / channel_pack;
if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
auto biasValue0 = FLOAT4(biasTerms[idx_c]);
auto biasValue1 = FLOAT4(biasTerms[idx_c+1]);
@ -445,7 +334,7 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
result1 += FLOAT4(in41 * w0);
result4 += FLOAT4(in40 * w1);
result5 += FLOAT4(in41 * w1);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
int widthSize = min(cst.output_width - idx_w, 2);
@ -454,79 +343,30 @@ kernel void conv1x1_w2c2(const device ftype4 *in [[buffer(0)]],
int channelSize = min(cst.output_slice - idx_c, 2);
if(channelSize > 1) {
/* true */ {xy_out[cst.output_size+0] = activate(ftype4(result4), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_size+1] = activate(ftype4(result5), cst.activation); }
/* true */ {xy_out[cst.output_size * cst.batch +0] = activate(ftype4(result4), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_size * cst.batch +1] = activate(ftype4(result5), cst.activation); }
}
}
kernel void conv1x1_w2h2(const device ftype4 *in [[buffer(0)]],
kernel void conv1x1_w4c2(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
int idx_w = gid.x << 1;
int idx_h = gid.y << 1;
int idx_c = gid.z % cst.output_slice;
int idx_b = gid.z / cst.output_slice;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto biasValue = FLOAT4(biasTerms[idx_c]);
FLOAT4 result0 = biasValue, result1 = biasValue;
FLOAT4 result4 = biasValue, result5 = biasValue;
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = xy_in0[0];
auto in41 = xy_in0[1];
auto in44 = xy_in0[cst.output_width+0];
auto in45 = xy_in0[cst.output_width+1];
auto w = xy_wt[z];
result0 += FLOAT4(in40 * w);
result1 += FLOAT4(in41 * w);
result4 += FLOAT4(in44 * w);
result5 += FLOAT4(in45 * w);
xy_in0 += cst.input_size;
}
int widthSize = min(cst.output_width - idx_w, 2);
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
int heightSize = min(cst.output_height - idx_h, 2);
if(heightSize > 1) {
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
}
}
kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant conv1x1_constants& cst [[buffer(2)]],
const device ftype4x4 *wt [[buffer(3)]],
const device ftype4 *biasTerms [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x * 2 >= cst.output_width || (int)gid.y * 2 >= cst.output_height || (int)gid.z * 2 >= cst.batch * cst.output_slice) return;
if ((int)gid.x * 4 >= cst.output_width || (int)gid.y * 2 >= cst.batch * cst.output_slice) return;
int channel_pack = (cst.output_channel + 7) >> 3;
int idx_w = gid.x << 1;
int idx_h = gid.y << 1;
int idx_c = (gid.z % channel_pack) << 1;
int idx_b = gid.z / channel_pack;
int idx_w = gid.x << 2;
int idx_h = 0;
int idx_c = (gid.y % channel_pack) << 1;
int idx_b = gid.y / channel_pack;
if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;
auto xy_wt = wt + idx_c * cst.input_slice;
auto xy_in0 = in + (int)idx_b * cst.input_slice * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_in0 = in + (int)idx_b * cst.input_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_slice * cst.output_size + idx_c * cst.output_size + idx_h * cst.output_width + idx_w;
auto xy_out = out + (int)idx_b * cst.output_size + idx_c * cst.output_size * cst.batch + idx_h * cst.output_width + idx_w;
auto biasValue0 = FLOAT4(biasTerms[idx_c]);
auto biasValue1 = FLOAT4(biasTerms[idx_c+1]);
@ -537,8 +377,8 @@ kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
for (auto z = 0; z < cst.input_slice; z++) {
auto in40 = xy_in0[0];
auto in41 = xy_in0[1];
auto in44 = xy_in0[cst.output_width+0];
auto in45 = xy_in0[cst.output_width+1];
auto in44 = xy_in0[2];
auto in45 = xy_in0[3];
auto w0 = xy_wt[z];
auto w1 = xy_wt[cst.input_slice+z];
@ -551,27 +391,20 @@ kernel void conv1x1_w2h2c2(const device ftype4 *in [[buffer(0)]],
result3 += FLOAT4(in41 * w1);
result6 += FLOAT4(in44 * w1);
result7 += FLOAT4(in45 * w1);
xy_in0 += cst.input_size;
xy_in0 += cst.input_size * cst.batch;
}
int widthSize = min(cst.output_width - idx_w, 2);
int widthSize = min(cst.output_width - idx_w, 4);
/* true */ *xy_out = activate(ftype4(result0), cst.activation);
if (widthSize > 1) {xy_out[1] = activate(ftype4(result1), cst.activation); }
int heightSize = min(cst.output_height - idx_h, 2);
if(heightSize > 1) {
/* true */ {xy_out[cst.output_width+0] = activate(ftype4(result4), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_width+1] = activate(ftype4(result5), cst.activation); }
}
if (widthSize > 2) {xy_out[2] = activate(ftype4(result4), cst.activation); }
if (widthSize > 3) {xy_out[3] = activate(ftype4(result5), cst.activation); }
int channelSize = min(cst.output_slice - idx_c, 2);
if(channelSize > 1) {
/* true */ xy_out[cst.output_size] = activate(ftype4(result2), cst.activation);
if (widthSize > 1) {xy_out[cst.output_size+1] = activate(ftype4(result3), cst.activation); }
if(heightSize > 1) {
/* true */ {xy_out[cst.output_size+cst.output_width+0] = activate(ftype4(result6), cst.activation); }
if (widthSize > 1) {xy_out[cst.output_size+cst.output_width+1] = activate(ftype4(result7), cst.activation); }
}
/* true */ xy_out[cst.output_size * cst.batch] = activate(ftype4(result2), cst.activation);
if (widthSize > 1) {xy_out[cst.output_size * cst.batch +1] = activate(ftype4(result3), cst.activation); }
if (widthSize > 2) {xy_out[cst.output_size * cst.batch +2] = activate(ftype4(result6), cst.activation); }
if (widthSize > 3) {xy_out[cst.output_size * cst.batch +3] = activate(ftype4(result7), cst.activation); }
}
}
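In the layout these rewritten kernels assume, NC4HW4 data is interleaved per batch: element (b, c, hw) of the output sits at offset c * output_size * batch + b * output_size + hw, which is why the pointer advances by input_size * batch per input slice and the next output slice is output_size * batch away, instead of the old per-batch strides.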

View File

@ -28,7 +28,7 @@ kernel void conv_depthwise(const device ftype4 *in [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice * cst.batch) return;
int oz = gid.z % cst.slice;
int oz = gid.z / cst.batch;
int offset_x = (int)gid.x * cst.stride_x - cst.pad_x;
int offset_y = (int)gid.y * cst.stride_y - cst.pad_y;
int sx = max(0, (UP_DIV(-offset_x, cst.dilation_x)));

View File

@ -36,8 +36,8 @@ kernel void deconv(const device ftype4 *in [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
int b = gid.z / cst.output_slice;
int o = gid.z % cst.output_slice;
int b = gid.z % cst.batch;
int o = gid.z / cst.batch;
FLOAT4 result = FLOAT4(biasTerms[o]);
int oy = (int)gid.y + cst.pad_y;
@ -56,12 +56,12 @@ kernel void deconv(const device ftype4 *in [[buffer(0)]],
int min_ix = (ox - max_kx * cst.dilation_x) / cst.stride_x;
auto o_wt = wt + o * cst.input_slice * cst.kernel_size;
auto b_in = in + b * cst.input_slice * cst.input_size;
auto b_in = in + b * cst.input_size;
for (auto z = 0; z < cst.input_slice; z++) {
for (auto ky = max_ky, iy = min_iy; ky >= min_ky; ky -= cst.delta_ky, iy += cst.delta_iy) {
for (auto kx = max_kx, ix = min_ix; kx >= min_kx; kx -= cst.delta_kx, ix += cst.delta_ix) {
auto wt4 = o_wt[z * cst.kernel_size + ky * cst.kernel_x + kx];
auto in4 = b_in[z * cst.input_size + iy * cst.input_width + ix];
auto in4 = b_in[z * cst.input_size * cst.batch + iy * cst.input_width + ix];
result += FLOAT4(in4 * wt4);
}
}
@ -78,7 +78,7 @@ kernel void deconv_depthwise(const device ftype4 *in [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return;
FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z % cst.input_slice)]);
FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z / cst.batch)]);
int oy = (int)gid.y + cst.pad_y;
int ox = (int)gid.x + cst.pad_x;

View File

@ -19,7 +19,7 @@ kernel void prelu_slopes(const device ftype4 *in [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { // size, slice, batch
if ((int)gid.x >= s.size || (int)gid.y >= s.slice) return;
int z = gid.z * s.slice + gid.y;
int z = gid.z + gid.y * s.batch;
auto v4 = in[z * s.size + int(gid.x)];
out[z * s.size + int(gid.x)] = select(v4, ftype4(slope[int(gid.y)]) * v4, signbit(v4));
}

View File

@ -2,10 +2,11 @@ struct ROI_shape {
int input_width;
int input_height;
int input_size;
int input_batch;
int output_width;
int output_height;
int output_size;
int slices;
int batch;
float spatial_scale;
};
kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
@ -15,10 +16,10 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;
int ob = gid.z / s.slices;
int iz = gid.z % s.slices;
int ob = gid.z % s.batch;
int iz = gid.z / s.batch;
auto b_roi = roi + ob * 8; // roundup(5, 4) = 8
auto b_roi = roi + ob * 5;
int ib = int(b_roi[0]);
int x1 = round(float(b_roi[1]) * s.spatial_scale);
int y1 = round(float(b_roi[2]) * s.spatial_scale);
@ -27,8 +28,8 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
int roi_w = max(x2 - x1 + 1, 1);
int roi_h = max(y2 - y1 + 1, 1);
auto bin_size_w = (ftype)roi_w / s.output_width;
auto bin_size_h = (ftype)roi_h / s.output_height;
float bin_size_w = (float)roi_w / (float)s.output_width;
float bin_size_h = (float)roi_h / (float)s.output_height;
int w_start = clamp(x1 + (int)floor(gid.x * bin_size_w) , 0, s.input_width);
int w_end = clamp(x1 + (int)ceil((gid.x + 1) * bin_size_w), 0, s.input_width);
@ -36,7 +37,7 @@ kernel void ROI_pooling(const device ftype4 *in [[buffer(0)]],
int h_end = clamp(y1 + (int)ceil((gid.y + 1) * bin_size_h), 0, s.input_height);
int is_empty = (h_end <= h_start) || (w_end <= w_start);
auto z_in = in + (ib * s.slices + iz) * s.input_size;
auto z_in = in + (ib + iz * s.input_batch) * s.input_size;
auto max4 = is_empty ? 0 : z_in[h_start * s.input_width + w_start];
for (int y = h_start; y < h_end; y++) {
auto y_in = z_in + y * s.input_width;

View File

@ -1,15 +0,0 @@
kernel void relu_x1(const device ftype *in [[buffer(0)]],
device ftype *out [[buffer(1)]],
constant float &slope [[buffer(2)]],
uint gid [[thread_position_in_grid]]) {
auto value = in[int(gid)];
out[int(gid)] = fmax(value, (ftype)0) + fmin(value, (ftype)0) * ftype(slope);
}
kernel void relu_x4(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant float &slope [[buffer(2)]],
uint gid [[thread_position_in_grid]]) {
auto value = in[int(gid)];
out[int(gid)] = fmax(value, (ftype4)0) + fmin(value, (ftype4)0) * ftype4(slope);
}

View File

@ -1,13 +1,24 @@
kernel void relu6_x1(const device ftype *in [[buffer(0)]],
device ftype *out [[buffer(1)]],
constant float4 &minMax [[buffer(2)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = clamp(in[int(gid)], (ftype)(minMax.x), (ftype)(minMax.y));
struct Param {
float minV;
float maxV;
int size;
int remain;
};
kernel void relu6(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant Param &p [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x < p.size) {
out[int(gid.x)] = clamp(in[int(gid.x)], (ftype4)p.minV, (ftype4)p.maxV);
}
}
kernel void relu6_x4(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant float4 &minMax [[buffer(2)]],
uint gid [[thread_position_in_grid]]) {
out[int(gid)] = clamp(in[int(gid)], (ftype4)minMax.x, (ftype4)minMax.y);
}
kernel void relu(const device ftype4 *in [[buffer(0)]],
device ftype4 *out [[buffer(1)]],
constant Param &p [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x < p.size) {
auto value = in[int(gid.x)];
out[int(gid.x)] = fmax(value, (ftype4)0) + fmin(value, (ftype4)0) * ftype4(p.minV);
}
}
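A minimal host-side sketch of how the new Param block might be filled (elementCount is hypothetical; the field meanings are inferred from the kernels above):
Param p;
p.minV   = 0.0f;                   // relu6 lower bound; the plain relu kernel reuses minV as the negative slope
p.maxV   = 6.0f;                   // relu6 upper bound
p.size   = (elementCount + 3) / 4; // number of ftype4 vectors covered by the dispatch
p.remain = elementCount % 4;       // tail scalars when the count is not a multiple of 4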

View File

@ -12,7 +12,7 @@ kernel void scale_ca(const device ftype4 *in [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) {
if ((int)gid.x >= s.size || (int)gid.y >= s.steps * s.batch) return;
int z = gid.y % s.steps;
int z = gid.y / s.batch;
out[int(gid.y) * s.size + int(gid.x)] =
in [int(gid.y) * s.size + int(gid.x)] * ftype4(scales[z]) + ftype4(biasTerms[z]);
}

View File

@ -294,6 +294,23 @@ private:
ImagePool* mBufferPool;
};
float OpenCLBackend::getBytes(const Tensor* tensor) {
float bytes = (float)tensor->getType().bytes();
if (getOpenCLRuntime()->isSupportedFP16()) {// Fp16
if (halide_type_float == tensor->getType().code) {
bytes = 2.0;
}
}
auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get();
if (nullptr != quant && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) {
bytes = 1.0;
}
if(tensor->getType().bits == 4) {
bytes = 0.5;
}
return bytes;
}
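As a rough worked example of getBytes (ignoring the ROUND_UP alignment applied later in onAcquire): a float tensor with 1x64x56x56 = 200704 elements is budgeted at 200704 * 2 ≈ 392 KB on an fp16-capable device, 200704 * 1 ≈ 196 KB when int8-quantized, and 200704 * 0.5 ≈ 98 KB for int4 weights.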
Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageType storageType) {
#ifdef LOG_VERBOSE
MNN_PRINT("Start OpenCLBackend::onAcquireBuffer !\n");
@ -312,7 +329,7 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
#ifndef MNN_OPENCL_BUFFER_CLOSED
if(mOpenCLRuntime->getGpuMemType() == BUFFER) {
size_t size;
size_t typeSize = 4;
float typeSize = getBytes(nativeTensor);
if (nativeTensor->dimensions() >= 2) {
auto alignC = ROUND_UP(C, 8);
// increment of height and width
@ -331,16 +348,9 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
size_t imageHeight = (size_t)N * H;
size = imageWidth*imageHeight*cPack;
}
cl_channel_type dataType = CL_FLOAT;
//when support and want fp16, use half datatype
if ((nativeTensor->getType().code == halide_type_int || nativeTensor->getType().code == halide_type_uint)){
if(nativeTensor->getType().bits == 8){
typeSize = 1;
}
} else if (getOpenCLRuntime()->isSupportedFP16()) {
typeSize = 2;
}
// Align size to an even value so an int4 (0.5-byte-per-element) allocation size stays integral
size = ROUND_UP(size, 2);
if (storageType == DYNAMIC_SEPERATE) {
auto buffer = mBufferPool->alloc(size*typeSize, true);
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer;
@ -352,23 +362,8 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
return new CLMemReleaseBuffer(buffer, mBufferPool);
}
MNN_ASSERT(storageType == STATIC);
#ifdef MNN_LOW_MEMORY
// for weight quant model's weight
if ((nativeTensor->getType().code == halide_type_int) &&
(nativeTensor->getType().bits == 8 || nativeTensor->getType().bits == 4)) {
// int8 quant
size_t alloc_size = size;
if (nativeTensor->getType().bits == 4) {
// int4 quant
alloc_size = size / 2;
}
auto buffer = mStaticBufferPool->alloc(alloc_size);
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer;
return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get());
}
#endif
auto buffer = mStaticBufferPool->alloc(size*
(dataType == CL_HALF_FLOAT ? sizeof(half_float::half) : sizeof(float)));
auto buffer = mStaticBufferPool->alloc(size*typeSize);
((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; // fix
return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get());
}
@ -569,7 +564,7 @@ ErrorCode OpenCLBackend::onResizeEnd() {
mOpenCLRuntime->setCommandQueueProfileDisable();
#endif
if(!mRecordings.empty()){
endRecord(mRecordings.back(), true);
endRecord(mRecordings.back().record, true);
}
return NO_ERROR;
}
@ -1188,8 +1183,25 @@ void OpenCLBackend::clearRecord() const{
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && mDevideOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
0, nullptr, 0, nullptr, 0, nullptr, nullptr);
std::vector<cl_array_arg_qcom> update_kernel_args;
std::vector<cl_workgroup_qcom> update_global_size;
std::vector<cl_workgroup_qcom> update_local_size;
for (int j = 0; j < mRecordings[i].updateInfo.size(); ++j){
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_kernel_args.size(); ++k){
update_kernel_args.emplace_back(mRecordings[i].updateInfo[j]->update_kernel_args[k]);
update_kernel_args.back().dispatch_index = j;
}
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_global_size.size(); ++k){
update_global_size.emplace_back(mRecordings[i].updateInfo[j]->update_global_size[k]);
update_global_size.back().dispatch_index = j;
}
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_local_size.size(); ++k){
update_local_size.emplace_back(mRecordings[i].updateInfo[j]->update_local_size[k]);
update_local_size.back().dispatch_index = j;
}
}
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i].record, update_kernel_args.size(), update_kernel_args.data(), 0, nullptr,
update_global_size.size(), update_global_size.data(), update_local_size.size(), update_local_size.data(), 0, nullptr, nullptr);
MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM");
}
mOpenCLRuntime->commandQueue().finish();
@ -1202,8 +1214,22 @@ void OpenCLBackend::enqeueRecord() const{
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && !mDevideOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
0, nullptr, 0, nullptr, 0, nullptr, nullptr);
std::vector<cl_array_arg_qcom> update_kernel_args;
std::vector<cl_workgroup_qcom> update_global_size;
std::vector<cl_workgroup_qcom> update_local_size;
for (int j = 0; j < mRecordings[i].updateInfo.size(); ++j){
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_kernel_args.size(); ++k){
update_kernel_args.emplace_back(mRecordings[i].updateInfo[j]->update_kernel_args[k]);
}
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_global_size.size(); ++k){
update_global_size.emplace_back(mRecordings[i].updateInfo[j]->update_global_size[k]);
}
for(int k = 0; k < mRecordings[i].updateInfo[j]->update_local_size.size(); ++k){
update_local_size.emplace_back(mRecordings[i].updateInfo[j]->update_local_size[k]);
}
}
cl_int res = mOpenCLRuntime->commandQueue().EnqueueRecordingQCOM(mRecordings[i].record, update_kernel_args.size(), update_kernel_args.data(), 0, nullptr,
update_global_size.size(), update_global_size.data(), update_local_size.size(), update_local_size.data(), 0, nullptr, nullptr);
MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM");
}
mOpenCLRuntime->commandQueue().finish();
@ -1215,7 +1241,7 @@ void OpenCLBackend::releaseRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && !mDevideOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
cl_int res = clReleaseRecordingQCOM(mRecordings[i]);
cl_int res = clReleaseRecordingQCOM(mRecordings[i].record);
MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM");
}
mRecordings.clear();
@ -1258,8 +1284,9 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
res = clEndRecordingQCOM(recording);
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
} else if(flag) {
// endRecord for the last recorded kernel when the record mode is MNN_GPU_RECORD_BATCH
if(!mRecordings.empty()){
cl_int res = clEndRecordingQCOM(mRecordings.back());
cl_int res = clEndRecordingQCOM(mRecordings.back().record);
mRecordNums = 0;
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
}
@ -1270,7 +1297,18 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
#endif //ENABLE_OPENCL_TIME_PROFILER
}
void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws) {
void OpenCLBackend::addRecord(cl_recording_qcom &record, std::vector<RecordUpdateInfo *>updateInfo){
if(mDevideOpRecord){
RecordInfo info;
info.record = record;
for(int i = 0; i < updateInfo.size(); ++i) {
info.updateInfo.emplace_back(updateInfo[i]);
}
mRecordings.emplace_back(info);
}
}
void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo) {
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(!mUseRecordQueue){
return;
@ -1281,17 +1319,35 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
#endif
cl_int res = CL_SUCCESS;
if(!mDevideOpRecord){
RecordInfo info;
if(updateInfo != nullptr){
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums;
}
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
updateInfo->update_global_size[i].dispatch_index = mRecordNums;
}
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
updateInfo->update_local_size[i].dispatch_index = mRecordNums;
}
info.updateInfo.emplace_back(updateInfo);
}
if(mRecordNums == 0){
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
mRecordings.emplace_back(recording);
info.record = recording;
mRecordings.emplace_back(info);
}else if(mRecordNums == mUseRecordableQueueSize){
res = clEndRecordingQCOM(mRecordings.back());
res = clEndRecordingQCOM(mRecordings.back().record);
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
mRecordings.emplace_back(recording);
info.record = recording;
mRecordings.emplace_back(info);
mRecordNums = 0;
} else if(updateInfo != nullptr){
auto &lastInfo = mRecordings.back();
lastInfo.updateInfo.emplace_back(updateInfo);
}
mRecordNums++;
}
@ -1317,7 +1373,7 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
#endif //ENABLE_OPENCL_TIME_PROFILER
}
void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws) {
void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo) {
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(!mUseRecordQueue){
return;
@ -1332,17 +1388,35 @@ void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, c
internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i]));
}
if(!mDevideOpRecord){
RecordInfo info;
if(updateInfo != nullptr){
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums;
}
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
updateInfo->update_global_size[i].dispatch_index = mRecordNums;
}
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
updateInfo->update_local_size[i].dispatch_index = mRecordNums;
}
info.updateInfo.emplace_back(updateInfo);
}
if(mRecordNums == 0){
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
mRecordings.emplace_back(recording);
info.record = recording;
mRecordings.emplace_back(info);
}else if(mRecordNums == mUseRecordableQueueSize){
res = clEndRecordingQCOM(mRecordings.back());
res = clEndRecordingQCOM(mRecordings.back().record);
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
cl_recording_qcom recording = mOpenCLRuntime->recordableQueue().NewRecordingQCOM(&res);
MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
mRecordings.emplace_back(recording);
info.record = recording;
mRecordings.emplace_back(info);
mRecordNums = 0;
} else if(updateInfo != nullptr){
auto &lastInfo = mRecordings.back();
lastInfo.updateInfo.emplace_back(updateInfo);
}
mRecordNums++;
}

View File

@ -34,6 +34,15 @@
namespace MNN {
namespace OpenCL {
struct TuneInfo;
struct RecordUpdateInfo{
std::vector<cl_array_arg_qcom> update_kernel_args;
std::vector<cl_workgroup_qcom> update_global_size;
std::vector<cl_workgroup_qcom> update_local_size;
};
struct RecordInfo{
cl_recording_qcom record;
std::vector<RecordUpdateInfo*> updateInfo;
};
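These structs let an execution pre-register which kernel arguments and work sizes may change between runs while the Qualcomm recording queue replays a fixed command sequence. A minimal sketch of the intended flow, with hypothetical variable names (the AttentionBufExecution below follows this pattern):
RecordUpdateInfo qkUpdate;
// each entry is {dispatch_index (rewritten by recordKernel3d), kernel arg index, arg size, pointer to the live value}
qkUpdate.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &bufferHandle}); // bufferHandle: a hypothetical cl_mem
backend->recordKernel3d(kernel, gws, lws, &qkUpdate);
// after a reallocation only the pointer is refreshed; the recorded command sequence is reused
qkUpdate.update_kernel_args[0].arg_value = &newBufferHandle;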
class CLRuntime : public Runtime {
public:
CLRuntime(const Backend::Info& info, int platformSize, int platformId, int deviceId = 0, void *contextPtr = nullptr, void *glshared = nullptr);
@ -111,6 +120,7 @@ public:
return mMemory;
}
float getBytes(const Tensor* tensor);
DataType getDataType(const Tensor* tensor);
cl_channel_type fpType();
@ -125,11 +135,9 @@ public:
bool isDevideOpRecord(){
return mDevideOpRecord;
}
void addRecord(cl_recording_qcom &record){
mRecordings.emplace_back(record);
}
void recordKernel2d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws);
void recordKernel3d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws);
void addRecord(cl_recording_qcom &record, std::vector<RecordUpdateInfo *>updateInfo);
void recordKernel2d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo = nullptr);
void recordKernel3d(const std::shared_ptr<KernelWrap> &kernel, const std::vector<uint32_t> &gws, const std::vector<uint32_t> &lws, RecordUpdateInfo *updateInfo = nullptr);
void startRecord(cl_recording_qcom &recording);
void endRecord(cl_recording_qcom &recording, bool flag = false);
@ -167,7 +175,7 @@ private:
BackendConfig::PrecisionMode mPrecision;
BackendConfig::MemoryMode mMemory;
bool mIsCreateError{false};
mutable std::vector<cl_recording_qcom> mRecordings;
mutable std::vector<RecordInfo> mRecordings;
bool mUseRecordQueue = false;
bool mDevideOpRecord = false;
uint32_t mRecordNums = 0;
@ -202,6 +210,17 @@ public:
}
#endif
#ifdef MNN_OPENCL_SEP_BUILD
#define REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(name, opType, memObj) \
OpenCLCreatorRegister<name> ___OpenCL##name##__##opType##__##memObj##__(opType, memObj)
#else
#define REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(name, opType, memObj) \
void ___OpenCL##name##__##opType##__##memObj##__() { \
static name _temp; \
OpenCLBackend::addCreator(std::make_pair(opType, memObj), &_temp); \
}
#endif
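A minimal usage sketch of the new macro from an execution source file (the Attention creator name is taken from the generated registration list further down; that it is registered through exactly this macro is an assumption):
REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(AttentionBufCreator, OpType_Attention, BUFFER);
// With MNN_OPENCL_SEP_BUILD this instantiates a static OpenCLCreatorRegister at load time;
// otherwise it only defines ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(), which
// the generated registerOpenCLOps() must call explicitly.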
template <typename T>
class TypedCreator : public OpenCLBackend::Creator {

View File

@ -1,122 +1,127 @@
// This file is generated by a shell script for op registration
#ifndef MNN_OPENCL_SEP_BUILD
namespace MNN {
namespace OpenCL {
extern void ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
extern void ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
extern void ___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
extern void ___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
extern void ___OpenCLScaleCreator__OpType_Scale__IMAGE__();
extern void ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
extern void ___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
extern void ___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
extern void ___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
extern void ___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
extern void ___OpenCLInterpCreator__OpType_Interp__IMAGE__();
extern void ___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
extern void ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
extern void ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
extern void ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
extern void ___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
extern void ___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
extern void ___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
extern void ___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
extern void ___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
extern void ___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
extern void ___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
extern void ___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
extern void ___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
extern void ___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
extern void ___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
extern void ___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
extern void ___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
extern void ___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
extern void ___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
extern void ___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
extern void ___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
extern void ___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
extern void ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
extern void ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
extern void ___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
extern void ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
extern void ___OpenCLReluCreator__OpType_ReLU__IMAGE__();
extern void ___OpenCLReluCreator__OpType_PReLU__IMAGE__();
extern void ___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
extern void ___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
extern void ___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
extern void ___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
extern void ___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
extern void ___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
extern void ___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
extern void ___OpenCLRasterCreator__OpType_Raster__IMAGE__();
extern void ___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
extern void ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
extern void ___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
extern void ___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
extern void ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
extern void ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
extern void ___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
extern void ___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
extern void ___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
extern void ___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
extern void ___OpenCLLoopCreator__OpType_While__IMAGE__();
extern void ___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
extern void ___OpenCLLoopBufCreator__OpType_While__BUFFER__();
extern void ___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
extern void ___OpenCLFuseCreator__OpType_Extra__IMAGE__();
extern void ___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
extern void ___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
extern void ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
extern void ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
extern void ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
extern void ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
extern void ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
extern void ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
extern void ___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
extern void ___OpenCLScaleCreator__OpType_Scale__IMAGE__();
extern void ___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
extern void ___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
extern void ___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
extern void ___OpenCLRangeCreator__OpType_Range__IMAGE__();
extern void ___OpenCLCastCreator__OpType_Cast__IMAGE__();
extern void ___OpenCLRasterCreator__OpType_Raster__IMAGE__();
extern void ___OpenCLFuseCreator__OpType_Extra__IMAGE__();
extern void ___OpenCLLoopCreator__OpType_While__IMAGE__();
extern void ___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
extern void ___OpenCLReluCreator__OpType_ReLU__IMAGE__();
extern void ___OpenCLReluCreator__OpType_PReLU__IMAGE__();
extern void ___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
extern void ___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
extern void ___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
extern void ___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
extern void ___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
extern void ___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
extern void ___OpenCLSelectCreator__OpType_Select__IMAGE__();
extern void ___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
extern void ___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
extern void ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
extern void ___OpenCLCastCreator__OpType_Cast__IMAGE__();
extern void ___OpenCLInterpCreator__OpType_Interp__IMAGE__();
extern void ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__();
#endif
void registerOpenCLOps() {
___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
___OpenCLScaleCreator__OpType_Scale__IMAGE__();
___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
___OpenCLInterpCreator__OpType_Interp__IMAGE__();
___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
___OpenCLReluCreator__OpType_ReLU__IMAGE__();
___OpenCLReluCreator__OpType_PReLU__IMAGE__();
___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
___OpenCLRasterCreator__OpType_Raster__IMAGE__();
___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
___OpenCLLoopCreator__OpType_While__IMAGE__();
___OpenCLLoopBufCreator__OpType_While__BUFFER__();
___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
___OpenCLFuseCreator__OpType_Extra__IMAGE__();
___OpenCLRangeCreator__OpType_Range__IMAGE__();
___OpenCLCastCreator__OpType_Cast__IMAGE__();
___OpenCLSelectCreator__OpType_Select__IMAGE__();
___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__();
___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__();
___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__();
___OpenCLArgMaxBufCreator__OpType_ArgMin__BUFFER__();
___OpenCLMatMulBufCreator__OpType_MatMul__BUFFER__();
___OpenCLRasterBufCreator__OpType_Raster__BUFFER__();
___OpenCLLayerNormBufCreator__OpType_LayerNorm__BUFFER__();
___OpenCLDepthwiseConvolutionBufCreator__OpType_ConvolutionDepthwise__BUFFER__();
___OpenCLInterpBufCreator__OpType_Interp__BUFFER__();
___OpenCLBinaryBufCreator__OpType_Eltwise__BUFFER__();
___OpenCLBinaryBufCreator__OpType_BinaryOp__BUFFER__();
___OpenCLConvolutionBufCreator__OpType_Convolution__BUFFER__();
___OpenCLSelectBufCreator__OpType_Select__BUFFER__();
___OpenCLPoolBufCreator__OpType_Pooling__BUFFER__();
___OpenCLDeconvolutionBufCreator__OpType_Deconvolution__BUFFER__();
___OpenCLCastBufCreator__OpType_Cast__BUFFER__();
___OpenCLReluBufCreator__OpType_ReLU__BUFFER__();
___OpenCLReluBufCreator__OpType_PReLU__BUFFER__();
___OpenCLReluBufCreator__OpType_ReLU6__BUFFER__();
___OpenCLSoftmaxBufCreator__OpType_Softmax__BUFFER__();
___OpenCLLoopBufCreator__OpType_While__BUFFER__();
___OpenCLRangeBufCreator__OpType_Range__BUFFER__();
___OpenCLUnaryBufCreator__OpType_UnaryOp__BUFFER__();
___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__();
___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__();
___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__();
___OpenCLScaleBufCreator__OpType_Scale__BUFFER__();
___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__();
___OpenCLMatMulCreator__OpType_MatMul__IMAGE__();
___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__();
___OpenCLUnaryCreator__OpType_Sigmoid__IMAGE__();
___OpenCLUnaryCreator__OpType_TanH__IMAGE__();
___OpenCLScaleCreator__OpType_Scale__IMAGE__();
___OpenCLSoftmaxCreator__OpType_Softmax__IMAGE__();
___OpenCLEltwiseCreator__OpType_Eltwise__IMAGE__();
___OpenCLEltwiseCreator__OpType_BinaryOp__IMAGE__();
___OpenCLRangeCreator__OpType_Range__IMAGE__();
___OpenCLRasterCreator__OpType_Raster__IMAGE__();
___OpenCLFuseCreator__OpType_Extra__IMAGE__();
___OpenCLLoopCreator__OpType_While__IMAGE__();
___OpenCLTrainableParamCreator__OpType_TrainableParam__IMAGE__();
___OpenCLReluCreator__OpType_ReLU__IMAGE__();
___OpenCLReluCreator__OpType_PReLU__IMAGE__();
___OpenCLReluCreator__OpType_ReLU6__IMAGE__();
___OpenCLConvolutionCreator__OpType_Convolution__IMAGE__();
___OpenCLLayerNormCreator__OpType_LayerNorm__IMAGE__();
___OpenCLReductionCreator__OpType_Reduction__IMAGE__();
___OpenCLRoiPoolingCreator__OpType_ROIPooling__IMAGE__();
___OpenCLPoolCreator__OpType_Pooling__IMAGE__();
___OpenCLSelectCreator__OpType_Select__IMAGE__();
___OpenCLDeconvolutionCreator__OpType_Deconvolution__IMAGE__();
___OpenCLDepthwiseDeconvolutionCreator__OpType_DeconvolutionDepthwise__IMAGE__();
___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__();
___OpenCLCastCreator__OpType_Cast__IMAGE__();
___OpenCLInterpCreator__OpType_Interp__IMAGE__();
___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__();
#endif
}
}
}

View File

@ -548,9 +548,7 @@ void runKernel2D(const ::std::shared_ptr<KernelWrap> &kernelw, const std::vector
void copyBufferToImage(OpenCLRuntime *runtime, const cl::Buffer &buffer, const cl::Image &image, int w, int h) {
std::set<std::string> buildOptions;
if(runtime->isWeightCpuTransHalf() == false) {
buildOptions.emplace("-DBUFFER_INP_FP32");
}
buildOptions.emplace("-DBUFFER_INP_FP32");
auto kernelW = runtime->buildKernelWithCache("copy_buffer_to_image2d", "copy_buffer_to_image2d", buildOptions);
auto kernel = kernelW->get();
auto status = kernel.setArg(0, buffer);

View File

@ -230,9 +230,11 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &mGPUComputeUnits);
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_CLOCK_FREQUENCY, &mMaxFreq);
mFirstGPUDevicePtr->getInfo(CL_DEVICE_MAX_MEM_ALLOC_SIZE, &mMaxMemAllocSize);
cl_device_fp_config fpConfig;
cl_device_fp_config fpConfig;
auto success = mFirstGPUDevicePtr->getInfo(CL_DEVICE_HALF_FP_CONFIG, &fpConfig);
mIsDeviceSupportedFP16 = CL_SUCCESS == success && fpConfig > 0;
bool checkFp16Extension = getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_khr_fp16");
mIsDeviceSupportedFP16 = (mIsDeviceSupportedFP16 && checkFp16Extension);
//set gpu mode, tuning level and memory object
setGpuMode(cl_mode);
@ -245,11 +247,17 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
}
}
auto permitFloat16 = false;
if (precision == BackendConfig::Precision_Low || (mMemType == BUFFER && precision == BackendConfig::Precision_Normal)) {//buffer mode not support Normal Precision yet
permitFloat16 = true;
mPrecisionLevel = 1;
if (mIsDeviceSupportedFP16) {
if (precision == BackendConfig::Precision_Low) {
mPrecisionLevel = 2;
} else if (precision == BackendConfig::Precision_Normal && mMemType == BUFFER) {
mPrecisionLevel = 0;
}
}
mIsSupportedFP16 = mIsDeviceSupportedFP16 && permitFloat16;
// Whether fp16 is used for input/output storage
mIsSupportedFP16 = (mPrecisionLevel == 2 || mPrecisionLevel == 0);
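// Resulting levels, restating the branches above: 2 = fp16 storage + fp16 compute (Precision_Low on an
// fp16-capable device), 0 = fp16 storage + fp32 compute (Precision_Normal with BUFFER memory on an
// fp16-capable device), 1 = fp32 storage + fp32 compute (the fallback).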
if(getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_integer_dot_product_int8")){
mSupportDotInt8 = true;
@ -257,7 +265,7 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
if(getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_integer_dot_product_accumulate_int8")){
mSupportDotAccInt8 = true;
}
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
{
if((false == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isQcomError())
@ -267,7 +275,7 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
cl_int err;
if(MaxRecordableQueueSize > 0){
// TODO: Use setSessionHint to configure mUseRecordableQueueSize
mUseRecordableQueueSize = 10;
mUseRecordableQueueSize = MaxRecordableQueueSize;
mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize;
mUseRecordQueue = true;
mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err);
@ -435,13 +443,6 @@ std::vector<size_t> OpenCLRuntime::getMaxImage2DSize() {
bool OpenCLRuntime::isSupportedFP16() const {
return mIsSupportedFP16;
}
bool OpenCLRuntime::isWeightCpuTransHalf() const {
#ifdef USE_HALF_WEIGHT_MEMORY
return mIsSupportedFP16;
#else
return false;//most of time
#endif
}
bool OpenCLRuntime::isDeviceSupportedFP16() const {
return mIsDeviceSupportedFP16;
@ -536,10 +537,12 @@ std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernel(const std::string &progra
std::shared_ptr<KernelWrap> OpenCLRuntime::buildKernelWithCache(const std::string &programName, const std::string &kernelName,
const std::set<std::string> &buildOptions, const Tensor *input, const Tensor *output, bool useCache) {
std::string buildOptionsStr;
if (mIsSupportedFP16) {
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16";
} else {
buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16";
if (mPrecisionLevel == 2) {// Fp16 Memory and fp16 compute
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=half -DCOMPUTE_FLOAT2=half2 -DCOMPUTE_FLOAT3=half3 -DCOMPUTE_FLOAT4=half4 -DCOMPUTE_FLOAT8=half8 -DCOMPUTE_FLOAT16=half16 -DCONVERT_COMPUTE_FLOAT2=convert_half2 -DCONVERT_COMPUTE_FLOAT4=convert_half4 -DCONVERT_COMPUTE_FLOAT8=convert_half8 -DCONVERT_COMPUTE_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16";
} else if (mPrecisionLevel == 0) {// Fp16 Memory and fp32 compute
buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DMNN_SUPPORT_FP16";
} else {// Fp32 Memory and fp32 compute
buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16";
}
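// Under this convention a kernel is expected to declare its buffers/images with the FLOAT types and its
// accumulators with the COMPUTE_FLOAT types, so precision level 0 keeps fp16 memory traffic while doing
// the arithmetic in fp32.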
if(nullptr != input){
@ -907,6 +910,7 @@ void OpenCLRuntime::printEventTime(){
return;
}
int raster_num = 0, raster_time = 0;
unsigned int conv_time = 0, while_time = 0;
for(int i = 0; i < mEvents.size(); ++i){
auto event = &mEvents[i].second;
cl_int res = event->wait();
@ -915,10 +919,16 @@ void OpenCLRuntime::printEventTime(){
auto StopNanos = event->getProfilingInfo<CL_PROFILING_COMMAND_END>();
auto kernel_time = (unsigned int)((StopNanos - StartNanos) / 1000.0);
mKernelTime += kernel_time;
if(mEvents[i].first == "ConvBuf2D" || (mEvents[i].first.length() >= 11 && mEvents[i].first.substr(0, 11) == "Convolution")) {
conv_time += kernel_time;
}
if((mEvents[i].first.length() >= 5 && mEvents[i].first.substr(0, 5) == "While")) {
while_time += kernel_time;
}
MNN_PRINT("kernel time = %d us %s\n", kernel_time, mEvents[i].first.c_str());
}
mEvents.clear();
MNN_PRINT("total kernel time = %d us\n", mKernelTime);
MNN_PRINT("total kernel time = %d us, conv time = %d us, while time = %d us\n", mKernelTime, conv_time, while_time);
#endif
}
} // namespace MNN

View File

@ -74,7 +74,6 @@ public:
OpenCLRuntime &operator=(const OpenCLRuntime &) = delete;
bool isSupportedFP16() const;
bool isWeightCpuTransHalf() const;
bool isDeviceSupportedFP16() const;
bool isDeviceSupportedLowPower() const;
bool isSupportedDotInt8() const;
@ -197,7 +196,9 @@ private:
uint32_t mUseRecordableQueueSize;
bool mUseRecordQueue = false;
bool mDevideOpRecord = true;
bool mIsSupportedFP16 = false;
int mPrecisionLevel;
bool mIsSupportedFP16 = false;
bool mIsDeviceSupportedFP16 = false;
bool mIsDeviceSupportedLowPower = false;
bool mSupportDotInt8 = false;

View File

@ -33,6 +33,14 @@ bool OpenCLSymbols::LoadOpenCLLibrary() {
"libGLES_mali.so",
"libmali.so",
"libOpenCL-pixel.so",
/*
#elif defined(__OHOS__)
"/vendor/lib64/chipsetsdk/libGLES_mali.so",
"/system/lib64/libGLES_mali.so",
"libGLES_mali.so",
"/vendor/lib64/chipsetsdk/libhvgr_v200.so",
"/vendor/lib64/chipsetsdk/libEGI_imp1.so",
*/
#if defined(__aarch64__)
// Qualcomm Adreno
"/system/vendor/lib64/libOpenCL.so",
@ -110,7 +118,7 @@ bool OpenCLSymbols::isSvmError() {
bool OpenCLSymbols::isPropError() {
return mPropError;
}
bool OpenCLSymbols::isQcomError() {
return mQcomError;
}
@ -118,11 +126,11 @@ bool OpenCLSymbols::isQcomError() {
bool OpenCLSymbols::isGlError() {
return mGlError;
}
bool OpenCLSymbols::isCL1_2Error() {
return mCL_12Error;
}
bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
#if defined(WIN32)
handle_ = LoadLibraryA(library_path.c_str());
@ -133,38 +141,38 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mIsError = true; \
}
#define MNN_LOAD_SVM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
if(func_name == nullptr){ \
mSvmError = true; \
}
#define MNN_LOAD_PROP_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
if(func_name == nullptr){ \
mPropError = true; \
}
#define MNN_LOAD_QCOM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
if(func_name == nullptr){ \
mQcomError = true; \
}
#define MNN_LOAD_CL_12_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
if(func_name == nullptr){ \
mCL_12Error = true; \
}
#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(handle_, #func_name)); \
if(func_name == nullptr){ \
mGlError = true; \
}
#else
handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL);
if (handle_ == nullptr) {
return false;
}
typedef void* (*loadOpenCLPointerFunc)(const char* name);
typedef void (*enableOpenCLFunc)();
loadOpenCLPointerFunc loadOpenCLPointer = nullptr;
@ -180,7 +188,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mIsError = true; \
}
#define MNN_LOAD_SVM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
@ -188,7 +196,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mSvmError = true; \
}
#define MNN_LOAD_PROP_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
@ -196,7 +204,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mPropError = true; \
}
#define MNN_LOAD_QCOM_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
@ -204,7 +212,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mQcomError = true; \
}
#define MNN_LOAD_CL_12_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
@ -212,7 +220,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mCL_12Error = true; \
}
#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name)); \
if(func_name == nullptr && loadOpenCLPointer != nullptr){ \
func_name = reinterpret_cast<func_name##Func>(loadOpenCLPointer(#func_name)); \
@ -220,7 +228,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
if(func_name == nullptr){ \
mGlError = true; \
}
#endif
MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs);
@ -277,7 +285,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) {
MNN_LOAD_CL_12_PTR(clCreateImage);
MNN_LOAD_CL_12_PTR(clRetainDevice);
MNN_LOAD_CL_12_PTR(clReleaseDevice);
MNN_LOAD_PROP_PTR(clCreateCommandQueueWithProperties);
MNN_LOAD_SVM_PTR(clSVMAlloc);
MNN_LOAD_SVM_PTR(clSVMFree);
@ -707,7 +715,7 @@ cl_mem CL_API_CALL clCreateFromGLTexture(cl_context context,
auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLTexture;
MNN_CHECK_NOTNULL(func);
return func(context, flags, target, miplevel, texture, errcode_ret);
}
cl_int CL_API_CALL clEnqueueAcquireGLObjects(cl_command_queue command_queue,

View File

@ -0,0 +1,427 @@
//
// AttentionBufExecution.cpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_OPENCL_BUFFER_CLOSED
#include "backend/opencl/execution/buffer/AttentionBufExecution.hpp"
namespace MNN {
namespace OpenCL {
AttentionBufImpl::AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cache)
: mKVCache(kv_cache){
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_channel", {"-DSOFTMAX_LOCAL_SIZE=512"});
mMaxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel));
}
void AttentionBufImpl::allocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
mMaxLength = mPastLength + mExpandChunk;
int byte = 4;
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
byte = 2;
}
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * byte;
// past_key: [1, numhead, headdim, maxlen]
mPastKey.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
// past_value: [1, numhead, maxlen, headdim]
mPastValue.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
}
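As a rough, hypothetical size check (shapes not taken from this diff): with mNumHead = 32, mHeadDim = 128, mMaxLength = 512 and fp16 storage (byte = 2), each cache buffer above is UP_DIV(512, 4) * 32 * 128 * 4 * 2 bytes = 4 MiB, so key plus value cost about 8 MiB per attention layer.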
void AttentionBufImpl::reallocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
int byte = 4;
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
byte = 2;
}
size_t old_size = mNumHead * UP_DIV(mMaxLength, 4) * mHeadDim * 4 * byte;
mMaxLength = mPastLength + mExpandChunk;
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * byte;
// past_key: [1, numhead, headdim, maxlen]
auto new_key = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
// past_value: [1, numhead, maxlen, headdim]
auto new_value = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
// copy
cl_int res;
auto new_key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_key, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
auto key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastKey.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
if(new_key_ptr != nullptr && key_ptr != nullptr && res == CL_SUCCESS){
::memcpy(new_key_ptr, key_ptr, old_size);
}else{
MNN_ERROR("Map error key_ptr == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_key, new_key_ptr);
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastKey.get(), key_ptr);
auto new_value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_value, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
auto value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastValue.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
if(new_value_ptr != nullptr && value_ptr != nullptr && res == CL_SUCCESS){
::memcpy(new_value_ptr, value_ptr, old_size);
}else{
MNN_ERROR("Map error value_ptr == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_value, new_value_ptr);
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastValue.get(), value_ptr);
mPastKey.reset(new_key);
mPastValue.reset(new_value);
size_t temp_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * byte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
// reset memory for args
if(mOpenCLBackend->isUseRecordQueue()){
mQkUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempQK.get()))();
mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mPastKey.get()))();
mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempQK.get()))();
mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempSoftMax.get()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempSoftMax.get()))();
mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mPastValue.get()))();
}else{
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(5, *mTempQK.get());
ret |= mKernel_qk->get().setArg(6, *mPastKey.get());
ret |= mKernel_softmax->get().setArg(3, *mTempQK.get());
ret |= mKernel_softmax->get().setArg(4, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(3, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(6, *mPastValue.get());
MNN_CHECK_CL_SUCCESS(ret, "reset memory arg for AttentionBufExecution");
}
}
int AttentionBufImpl::getLocalSize(int size, int maxGroupSize){
int local_size = 1;
while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){
local_size *= 2;
}
return local_size;
}
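For example, getLocalSize(100, 512) doubles local_size through 1, 2, 4, ... and stops at 64, the largest power of two that exceeds neither the requested size nor the work-group limit.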
ErrorCode AttentionBufImpl::onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
mOpenCLBackend->startRecord(mRecording);
//clear update arg vector, if prefill and decode use the same one
mOpRecordUpdateInfo.clear();
mQkUpdateInfo.update_kernel_args.clear();
mQkUpdateInfo.update_global_size.clear();
mQkUpdateInfo.update_local_size.clear();
mSoftMaxUpdateInfo.update_kernel_args.clear();
mSoftMaxUpdateInfo.update_global_size.clear();
mSoftMaxUpdateInfo.update_local_size.clear();
mQkvUpdateInfo.update_kernel_args.clear();
mQkvUpdateInfo.update_global_size.clear();
mQkvUpdateInfo.update_local_size.clear();
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto shape = query->shape();
int seq_len = shape[1];
mNumHead = shape[2];
mHeadDim = shape[3];
mScale = 1.0 / sqrt(mHeadDim);
mIsDecode = seq_len == 1;
mIsFirstDecode = true;
if (mPastLength == 0 || seq_len > 1) {
mPastLength = seq_len;
}
mKv_seq_len = mPastLength;
if(mIsDecode){
mKv_seq_len = mPastLength + 1;
}
allocKVCache();
int byte = 4;
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
byte = 2;
}
if (mIsDecode) {
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * byte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
} else {
size_t buffer_size = UP_DIV(mPastLength, 4) * mPastLength * mNumHead * 4 * byte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
}
// query * key -> div -> select
{
std::set<std::string> buildOption;
if(!mIsDecode){
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
}
if((mHeadDim % 4) != 0){
buildOption.emplace("-DHEADDIM_LEAVE");
}
mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask", buildOption, inputs[0], outputs[0]);
mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mKv_seq_len, 4))};
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qk));
mGlobalWorkSizeQk2 = UP_DIV(mKv_seq_len, 4);
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk2);
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query));
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(key));
ret |= mKernel_qk->get().setArg(index++, *mTempQK.get());
ret |= mKernel_qk->get().setArg(index++, *mPastKey.get());
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mask));
ret |= mKernel_qk->get().setArg(index++, mScale);
ret |= mKernel_qk->get().setArg(index++, seq_len);
ret |= mKernel_qk->get().setArg(index++, mKv_seq_len);
ret |= mKernel_qk->get().setArg(index++, mNumHead);
ret |= mKernel_qk->get().setArg(index++, mHeadDim);
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask");
mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask", mKernel_qk).first;
mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2]));
mQkUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(mGlobalWorkSizeQk2), &mGlobalWorkSizeQk2});
mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &(*(mTempQK.get()))()});
mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastKey.get()))()});
mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKv_seq_len), &mKv_seq_len});
mQkGlobal_size[0] = mGlobalWorkSizeQk[0];
mQkGlobal_size[1] = mGlobalWorkSizeQk[1];
mQkGlobal_size[2] = mGlobalWorkSizeQk[2];
mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size});
mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
}
// softmax
{
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
int localSize = getLocalSize(mKv_seq_len, MaxLocalSize);
if(localSize < 4){
localSize = 1;
}
int past_len4 = UP_DIV(mKv_seq_len, 4);
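// kv_seq_len is processed in groups of 4; mSoftMaxRemainChannels counts the padded lanes the softmax kernel must ignore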
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[0] = mNumHead;
mSoftmaxShape[1] = past_len4;
mSoftmaxShape[2] = 1;
mSoftmaxShape[3] = mPastLength;
std::set<std::string> buildOption;
buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
if(!mIsDecode){
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_width", buildOption, inputs[0], outputs[0]);
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(past_len4), static_cast<uint32_t>(mNumHead)};
} else{
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_channel", buildOption, inputs[0], outputs[0]);
mSoftmaxShape[3] = 1;
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(1), static_cast<uint32_t>(mNumHead)};
}
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_softmax));
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
ret |= mKernel_softmax->get().setArg(index++, *mTempQK.get());
ret |= mKernel_softmax->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_softmax->get().setArg(index++, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(index++, mSoftmaxShape);
MNN_CHECK_CL_SUCCESS(ret, "setArg softmax");
mLocalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), 1, 1};
if(localSize == 1){
mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax).first;
}
mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempQK.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mSoftMaxRemainChannels), &mSoftMaxRemainChannels});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mSoftmaxShape), &mSoftmaxShape});
mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
}
// qk * value
{
std::set<std::string> buildOption;
if(!mIsDecode){
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
}
if((mHeadDim % 4) != 0){
buildOption.emplace("-DHEADDIM_LEAVE");
}
mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv", buildOption, inputs[0], outputs[0]);
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qkv));
mGlobalWorkSizeQkv = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(mHeadDim)};
if(mIsDecode){
mGlobalWorkSizeQkv = {static_cast<uint32_t>(1), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mHeadDim, 4))};
}
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]);
ret |= mKernel_qkv->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(value));
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
ret |= mKernel_qkv->get().setArg(index++, *mPastValue.get());
ret |= mKernel_qkv->get().setArg(index++, seq_len);
ret |= mKernel_qkv->get().setArg(index++, mKv_seq_len);
ret |= mKernel_qkv->get().setArg(index++, mNumHead);
ret |= mKernel_qkv->get().setArg(index++, mHeadDim);
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv");
mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv", mKernel_qkv).first;
mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2]));
mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastValue.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 8, sizeof(mKv_seq_len), &mKv_seq_len});
mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo);
}
mOpenCLBackend->endRecord(mRecording);
return NO_ERROR;
}
ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("start AttentionBufExecution onExecute !\n");
#endif
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
reallocKVCache();
#ifdef ENABLE_OPENCL_TIME_PROFILER
{
cl::Event event;
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event});
}
#else
if(mOpenCLBackend->isUseRecordQueue()){
mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
if(mIsDecode){
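// the first decode step runs with the kv_seq_len prepared in onResize; from the second step on, advance the cache length and the values the recorded kernels read through the update args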
if(mIsFirstDecode){
mIsFirstDecode = false;
}else{
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
mGlobalWorkSizeQk2 = past_len4;
mQkGlobal_size[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
}
}
#ifdef LOG_VERBOSE
MNN_PRINT("End AttentionBufExecution onExecute... \n");
#endif
return NO_ERROR;
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime());
#endif
// decode
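// non-recorded path: after this token's kernels have run, bump the cache length and rewrite every kernel argument that depends on kv_seq_len for the next decode step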
if(mIsDecode){
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
cl_int ret = CL_SUCCESS;
mGlobalWorkSizeQk2 = past_len4;
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
ret |= mKernel_qk->get().setArg(2, mGlobalWorkSizeQk2);
ret |= mKernel_qk->get().setArg(10, mKv_seq_len);
ret |= mKernel_softmax->get().setArg(5, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(6, mSoftmaxShape);
ret |= mKernel_qkv->get().setArg(8, mKv_seq_len);
MNN_CHECK_CL_SUCCESS(ret, "reset arg for AttentionBufExecution");
}
#ifdef LOG_VERBOSE
MNN_PRINT("end AttentionBufExecution onExecute !\n");
#endif
return NO_ERROR;
}
AttentionBufExecution::AttentionBufExecution(const MNN::Op *op, Backend* backend, bool kv_cache) : CommonExecution(backend, op) {
mImpl.reset(new AttentionBufImpl(op, backend, kv_cache));
}
AttentionBufExecution::AttentionBufExecution(std::shared_ptr<AttentionBufImpl> impl, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mImpl(impl) {}
ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onResize(backend(), inputs, outputs);
}
ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onExecute(backend(), inputs, outputs);
}
bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new AttentionBufExecution(mImpl, op, bn);
return true;
}
class AttentionBufCreator : public OpenCLBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
auto param = op->main_as_AttentionParam();
return new AttentionBufExecution(op, backend, param->kv_cache());
}
};
REGISTER_OPENCL_OP_CREATOR(AttentionBufCreator, OpType_Attention, BUFFER);
} // namespace OpenCL
} // namespace MNN
#endif/* MNN_OPENCL_BUFFER_CLOSED */
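For reference while reading the work-size math above: UP_DIV and ROUND_UP are MNN's ceiling-division and round-to-multiple helpers. A minimal sketch of the assumed semantics (illustrative helper names, not the macros verbatim):

// Illustrative sketch only; up_div/round_up mirror the assumed behaviour of UP_DIV/ROUND_UP.
static inline int up_div(int x, int y)   { return (x + y - 1) / y; }        // e.g. up_div(10, 4)   == 3
static inline int round_up(int x, int y) { return ((x + y - 1) / y) * y; }  // e.g. round_up(10, 4) == 12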


@ -0,0 +1,82 @@
//
// AttentionBufExecution.hpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifndef AttentionBufExecution_hpp
#define AttentionBufExecution_hpp
#include "backend/opencl/execution/image/CommonExecution.hpp"
namespace MNN {
namespace OpenCL {
class AttentionBufImpl {
public:
AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cache);
~AttentionBufImpl() {
if(mRecording != NULL){
#ifdef MNN_USE_LIB_WRAPPER
clReleaseRecordingQCOM(mRecording);
#endif
}
}
ErrorCode onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
private:
int getLocalSize(int size, int maxGroupSize);
void allocKVCache();
void reallocKVCache();
bool mKVCache;
float mScale;
const int mExpandChunk = 64;
bool mIsDecode = false;
bool mIsFirstDecode = true;
int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0, mSoftMaxRemainChannels = 0;
std::shared_ptr<cl::Buffer> mPastKey, mPastValue, mTempQK, mTempSoftMax;
int mNumHead = 0, mHeadDim = 0, mValueH = 0;
std::shared_ptr<KernelWrap> mKernel_qk;
std::shared_ptr<KernelWrap> mKernel_softmax;
std::shared_ptr<KernelWrap> mKernel_qkv;
std::vector<uint32_t> mGlobalWorkSizeQk{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQk{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeSoftMax{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeSoftMax{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeQkv{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQkv{1, 1, 1, 1};
uint32_t mMaxWorkGroupSize;
OpenCLBackend *mOpenCLBackend;
RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo;
RecordUpdateInfo mQkvUpdateInfo;
int mGlobalWorkSizeQk2 = 0;
size_t mQkGlobal_size[3];
int mSoftmaxShape[4];
cl_recording_qcom mRecording{NULL};
std::vector<RecordUpdateInfo*> mOpRecordUpdateInfo;
};
class AttentionBufExecution : public CommonExecution {
public:
AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache);
AttentionBufExecution(std::shared_ptr<AttentionBufImpl> impl, const MNN::Op *op, Backend *backend);
virtual ~AttentionBufExecution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
std::shared_ptr<AttentionBufImpl> mImpl;
};
} // namespace OpenCL
} // namespace MNN
#endif /* AttentionBufExecution_hpp */
#endif /* MNN_OPENCL_BUFFER_CLOSED */


@ -379,7 +379,7 @@ public:
case BinaryOpOperation_NOTEQUAL:
return new BinaryBufExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend);
case BinaryOpOperation_MOD:
return new BinaryBufExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend);
return new BinaryBufExecution(inputs, "fmod(in0,in1)", op, backend);
default:
break;
}


@ -354,12 +354,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
std::shared_ptr<Tensor> filterBuffer(
Tensor::createDevice<float>({mResource->mOutputChannel, ROUND_UP(mResource->mInputChannel, 4), mResource->mKernelWidth, mResource->mKernelHeight}));
int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
int buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
@ -367,25 +362,10 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
if(ptrCL != nullptr && res == CL_SUCCESS) {
::memset(ptrCL, 0, buffer_size);
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
for(int kh=0; kh<mResource->mKernelHeight; kh++) {
for(int kw=0; kw<mResource->mKernelWidth; kw++) {
int dst_idx = ((oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelHeight + kh)* mResource->mKernelWidth + kw;
int src_idx = ((oc * mResource->mInputChannel + ic) * mResource->mKernelHeight + kh)* mResource->mKernelWidth + kw;
((half_float::half*)ptrCL)[dst_idx] = (half_float::half)(mFilterDataPtr[src_idx]);
}
}
}
}
}else{
const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float);
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size);
}
const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float);
for(int oc=0; oc<mResource->mOutputChannel; oc++) {
for(int ic=0; ic<mResource->mInputChannel; ic++) {
::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size);
}
}
}else{
@ -397,10 +377,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};
bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), needTrans);
}
}
@ -697,8 +674,7 @@ ErrorCode ConvBufExecution::onExecute(const std::vector<Tensor *> &inputs, const
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"ConvBuf2D", event});
#else
if(mOpenCLBackend->isUseRecordQueue()){
if(mOpenCLBackend->isDevideOpRecord())
mOpenCLBackend->addRecord(mRecording);
mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
#ifdef LOG_VERBOSE
MNN_PRINT("End ConvExecution onExecute... \n");
#endif
@ -729,7 +705,7 @@ public:
#ifdef MNN_LOW_MEMORY
{
auto conv2dParams = op->main_as_Convolution2D();
if ((static_cast<OpenCLBackend *>(backend)->getMemory() == BackendConfig::Memory_Low) && (conv2dParams->quanParameter() != nullptr)) {
if (conv2dParams->quanParameter() != nullptr) {
if (((conv2dParams->quanParameter()->type() == 4) ||
(conv2dParams->quanParameter()->type() == 1) ||
(conv2dParams->quanParameter()->type() == 2))) {


@ -19,6 +19,7 @@ struct ConvBufResource {
const Convolution2DCommon *mConv2dCommonParams;
const Convolution2D *mConv2dParams;
std::shared_ptr<cl::Buffer> mKernelBuffer;
std::shared_ptr<cl::Image2D> mKernelImage;
std::shared_ptr<Tensor> dequantScale;
std::shared_ptr<Tensor> dequantOffset;
std::shared_ptr<Tensor> mFilter;
@ -33,6 +34,7 @@ struct ConvBufResource {
bool mConv1x1Opt = false;
bool mConv1x1C8Opt = false;
std::shared_ptr<Execution> mRasterExe;
bool mUseImage = false;
};
class ConvBufCommonExecution {


@ -13,7 +13,7 @@ namespace OpenCL {
// set mDequantScale mDequantOffset mNumQuantBit mFilterDataPtr from mConv2dParams
void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
quanCommon = ConvolutionCommon::load(mResource->mConv2dParams, this->backend(), false, true);
if ((mOpenCLBackend->getMemory() == BackendConfig::Memory_Low) && (mResource->mConv2dParams->quanParameter() != nullptr)) {
if (mResource->mConv2dParams->quanParameter() != nullptr) {
mLowMemoryFlag = true;
} else {
MNN_ERROR("Conv buf low memory init error.\n");
@ -30,44 +30,34 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<Convoluti
int numAlpha = mResource->mOutputChannel;
// set mDequantScale mDequantOffset
int numAlphaPack = ROUND_UP(numAlpha, 16);
mResource->dequantScale.reset(Tensor::createDevice<float>({numAlphaPack}));
mResource->dequantOffset.reset(Tensor::createDevice<float>({numAlphaPack}));
mResource->dequantScale.reset(Tensor::createDevice<int32_t>({numAlphaPack}));
mResource->dequantOffset.reset(Tensor::createDevice<int32_t>({numAlphaPack}));
mOpenCLBackend->onAcquireBuffer(mResource->dequantScale.get(), Backend::STATIC);
mOpenCLBackend->onAcquireBuffer(mResource->dequantOffset.get(), Backend::STATIC);
cl::Buffer &dequantScaleBuffer = openCLBuffer(mResource->dequantScale.get());
cl::Buffer &dequantOffsetBuffer = openCLBuffer(mResource->dequantOffset.get());
// transfer data from src in cpu to dst in gpu
int bytes = mOpenCLBackend->fpBytes();
int fpBytes = mOpenCLBackend->fpBytes();
cl_int resBias, resScale, resOffset;
void * dequantScaleBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantScaleBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * bytes, nullptr, nullptr, &resScale);
void * dequantOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantOffsetBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * bytes, nullptr, nullptr, &resOffset);
::memset(dequantScaleBufferMap, -1, numAlphaPack * bytes);
::memset(dequantOffsetBufferMap, 0, numAlphaPack * bytes);
void * dequantScaleBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantScaleBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * sizeof(int32_t), nullptr, nullptr, &resScale);
void * dequantOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantOffsetBuffer, true, CL_MAP_WRITE, 0, numAlphaPack * sizeof(int32_t), nullptr, nullptr, &resOffset);
::memset(dequantScaleBufferMap, -1, numAlphaPack * sizeof(int32_t));
::memset(dequantOffsetBufferMap, 0, numAlphaPack * sizeof(int32_t));
if (dequantScaleBufferMap != nullptr && dequantOffsetBufferMap != nullptr && resScale == CL_SUCCESS && resOffset == CL_SUCCESS) {
if (bytes == 2) {
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((half_float::half *)dequantOffsetBufferMap)[i] = (half_float::half)dequantAlpha[2 * i];
((half_float::half *)dequantScaleBufferMap)[i] = (half_float::half)dequantAlpha[2 * i + 1];
}
} else {
for (int i = 0; i < numAlpha; ++i) {
((half_float::half *)dequantScaleBufferMap)[i] = (half_float::half)dequantAlpha[i];
((half_float::half *)dequantOffsetBufferMap)[i] = 0.0f;
}
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantOffsetBufferMap)[i] = dequantAlpha[2 * i];
((float *)dequantScaleBufferMap)[i] = dequantAlpha[2 * i + 1];
}
} else {
if (quanCommon->asymmetric) {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantOffsetBufferMap)[i] = dequantAlpha[2 * i];
((float *)dequantScaleBufferMap)[i] = dequantAlpha[2 * i + 1];
}
} else {
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantScaleBufferMap)[i] = dequantAlpha[i];
((float *)dequantOffsetBufferMap)[i] = 0.0f;
}
for (int i = 0; i < numAlpha; ++i) {
((float *)dequantScaleBufferMap)[i] = dequantAlpha[i];
((float *)dequantOffsetBufferMap)[i] = 0.0f;
}
}
} else {
@ -82,7 +72,7 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<Convoluti
// set mKernelBuffer for the 1x1 kernels
void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin, void * filterDataPtr, std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
cl_int res;
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({ROUND_UP(mResource->mOutputChannel, 8)/*Cout pack set to max 8*/, ROUND_UP(mResource->mInputChannel, packCin), mResource->mKernelWidth, mResource->mKernelHeight}));
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({ROUND_UP(mResource->mOutputChannel, packCout), ROUND_UP(mResource->mInputChannel, packCin), mResource->mKernelWidth, mResource->mKernelHeight}));
size_t buffer_size = filterBuffer->usize() / sizeof(float);
float *dequantAlpha = quanCommon->alpha.get();
// shared part for all cases
@ -93,43 +83,66 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin,
// int4 case
buffer_size /= 2;
} else {/* More types to be supported. */}
mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
auto kernelBufferPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
if(kernelBufferPtr != nullptr && res == CL_SUCCESS){
::memset(kernelBufferPtr, 0, buffer_size);
// Use Image load weights
void *mapPtr = nullptr;
size_t row_pitch;
size_t slice_pitch;
if(UP_DIV(mResource->mInputChannel, packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, packCout) <= 16384){
mResource->mUseImage = true;
}
if(mResource->mUseImage) {
if(mNumQuantBit == 4){
packCin *= 2;
}
size_t w = ROUND_UP(mResource->mOutputChannel, packCout);
size_t h = UP_DIV(mResource->mInputChannel, packCin);
mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_FLOAT), w, h, 0, nullptr, &res));
if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) {
MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, res);
}
mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapImage(*(mResource->mKernelImage.get()), true, CL_MAP_WRITE, {0, 0, 0}, {w, h, 1}, &row_pitch, &slice_pitch, nullptr, nullptr, &res);
if(mNumQuantBit == 4){
row_pitch *= 2;
}
} else{
mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
row_pitch = ROUND_UP(mResource->mOutputChannel, packCout) * packCin;
}
if(mapPtr != nullptr && res == CL_SUCCESS){
for(int o = 0; o < mResource->mOutputChannel; o++){
float zero = 0;
if(quanCommon->asymmetric){
zero = (-dequantAlpha[2 * o + 1])/dequantAlpha[2 * o];
}
int i = 0;
for(; i < mResource->mInputChannel; i++){
int bufferIdx = (o/packCout) * packCin*packCout + (i/packCin)*packCin*ROUND_UP(mResource->mOutputChannel, packCout) + (o%packCout)*packCin + (i%packCin);//(Ci/packCin, Co/packCout, packCout, packCin)
for(; i < mResource->mInputChannel ; i++){
int bufferIdx = (i/packCin) * row_pitch + o*packCin + (i%packCin);//(Ci/packCin, Co/packCout * packCout * packCin)
int filterIdx = o*mResource->mInputChannel + i;
if (mNumQuantBit == 8) {
// int8 case
((int8_t *)kernelBufferPtr)[bufferIdx] = (int8_t)(((int8_t *)filterDataPtr)[filterIdx]);
((int8_t *)mapPtr)[bufferIdx] = (int8_t)(((int8_t *)filterDataPtr)[filterIdx]);
} else if (mNumQuantBit == 4){
// int4 case
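// two signed int4 weights share one byte: +8 shifts [-8, 7] into [0, 15]; even indices land in the high nibble (*16), odd indices in the low nibble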
if (bufferIdx % 2 == 0) {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)((((int8_t *)filterDataPtr)[filterIdx] + 8) * 16);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)((((int8_t *)filterDataPtr)[filterIdx] + 8) * 16);
} else {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)(((int8_t *)filterDataPtr)[filterIdx] + 8);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)(((int8_t *)filterDataPtr)[filterIdx] + 8);
}
} else {/* More types to be supported. */}
}
for(; i < ROUND_UP(mResource->mInputChannel, 4); i++){
int bufferIdx = (o/packCout) * packCin*packCout + (i/packCin)*packCin*ROUND_UP(mResource->mOutputChannel, packCout) + (i%packCin)*packCout + (o%packCout);//(Ci/packCin, Co/packCout, packCin, packCout)
for(; i < ROUND_UP(mResource->mInputChannel, packCin); i++){
int bufferIdx = (i/packCin) * row_pitch + o*packCin + (i%packCin);//(Ci/packCin, Co/packCout * packCout * packCin)
if (mNumQuantBit == 8) {
// int8 case
((int8_t *)kernelBufferPtr)[bufferIdx] = (int8_t)(zero);
((int8_t *)mapPtr)[bufferIdx] = (int8_t)(zero);
} else if (mNumQuantBit == 4){
// int4 case
if (bufferIdx % 2 == 0) {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)((zero + 8) * 16);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)((zero + 8) * 16);
} else {
((uint8_t *)kernelBufferPtr)[bufferIdx / 2] += (uint8_t)(zero + 8);
((uint8_t *)mapPtr)[bufferIdx / 2] += (uint8_t)(zero + 8);
}
}
}
@ -138,7 +151,11 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin,
MNN_ERROR("set1x1WeightLowMemory: Map error ptrCL == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelBuffer.get()), kernelBufferPtr);
if(mResource->mUseImage){
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelImage.get()), mapPtr);
} else{
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelBuffer.get()), mapPtr);
}
}
// set mFilter for the general kernels
void ConvBufLowMemoryExecution::setGeneralWeightLowMemory(void* filterDataPtr, std::shared_ptr<ConvolutionCommon::Int8Common> & quanCommon) {
@ -321,13 +338,11 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
const int width = outputShape.at(2);
const int inputChannelBlocks = UP_DIV(inputChannels, 4);
const int outputChannelBlocks = UP_DIV(outChannel, 4);
std::string kernelname = "gemm_conv_buf";
int global_x = outputChannelBlocks;
int global_y = batch * height;
const int total_kernel = 3;
std::string kernelName[total_kernel] = {"gemm_conv_c1_buf", "gemm_conv_c2_buf", "gemm_conv_c4_buf",};
int itemC[total_kernel] = {1, 2, 4};
const int total_kernel = 5;
std::string kernelName[total_kernel] = {"gemm_conv_c1_buf", "gemm_conv_c2_buf", "gemm_conv_c4_buf", "gemm_conv_c1_image", "gemm_conv_c2_image"};
int itemC[total_kernel] = {1, 2, 4, 1, 2};
int actual_kernel = total_kernel;
std::shared_ptr<KernelWrap> kernel[total_kernel];
std::vector<uint32_t> globalWorkSize[total_kernel];
@ -337,12 +352,27 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
if(width == 1 && height == 1){
buildOption.emplace("-DWIDTH_HEIGHT_1");
}
if(inputChannels % 16 != 0){
buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
} else if (mResource->mUseImage && mNumQuantBit == 4 && inputChannels % 32 != 0) {
// Image weight-int4 use load32
buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
}
std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel);
for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", kernelName[knl_idx], buildOption);
if(batch > 1){
global_y = UP_DIV(batch, 2) * height;
buildOption.emplace("-DBACTH_BLOCK2");
info += "_BATCH_BLOCK2";
}
int knl_idx = 0;
actual_kernel = 3;
if(mResource->mUseImage){
knl_idx = 3;
actual_kernel = total_kernel;
}
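// buffer-backed weights tune gemm_conv_c{1,2,4}_buf (indices 0..2); image-backed weights tune gemm_conv_c{1,2}_image (indices 3..4)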
for (; knl_idx < actual_kernel; knl_idx++) {
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[knl_idx], buildOption);
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outChannel, itemC[knl_idx]) * width), static_cast<uint32_t>(global_y)};
@ -351,7 +381,11 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]);
ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]);
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input));
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelBuffer.get());
if(mResource->mUseImage){
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantScale.get()));
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantOffset.get()));
ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
@ -361,7 +395,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(batch));
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(height));
ret |= kernel[knl_idx]->get().setArg(idx++, static_cast<int>(width));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv_buf Kernel Select");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select");
std::pair<std::vector<uint32_t>, int> retTune;
retTune = gws2dLwsTune(kernel[knl_idx], globalWorkSize[knl_idx], kernelName[knl_idx] + info, maxWorkGroupSize);
if(min_cost.first > retTune.second) {
@ -374,14 +408,18 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", kernelName[min_index], buildOption);
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[min_index], buildOption);
//MNN_PRINT("Kernel is %d.\n", min_index);
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
if(mResource->mUseImage){
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScale.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantOffset.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
@ -391,7 +429,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(batch));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(height));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(width));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv_buf");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf");
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]};
@ -422,7 +460,7 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector<Tensor *>
// prepare mDequantScale mDequantOffset mFilterDataPtr
getInfoFromOpLowMemory(quanCommon);
//select opt conv method
if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1) {
if (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1 && conv2dCommonParams->padX() == 0 && conv2dCommonParams->padY() == 0 && conv2dCommonParams->dilateX() == 1 && conv2dCommonParams->dilateY() == 1) {
set1x1WeightLowMemory(4, 16, mFilterDataPtr, quanCommon);
mResource->mConv1x1Opt = true;
}else {


@ -48,24 +48,13 @@ DeconvBufExecution::DeconvBufExecution(const std::vector<Tensor *> &inputs, cons
std::shared_ptr<Tensor> filterBuffer(
Tensor::createDevice<float>({outputChannel, inputChannel, kernelHeight, kernelWidth}));
int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
size_t buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
cl_int error;
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
if(ptrCL != nullptr && error == CL_SUCCESS){
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for(int i=0; i<filterBuffer->elementSize(); i++) {
((half_float::half*)ptrCL)[i] = (half_float::half)(filterDataPtrTransformed[i]);
}
}else{
::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size());
}
::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size());
}else{
MNN_ERROR("Map error ptrCL == nullptr \n");
}
@ -75,10 +64,7 @@ DeconvBufExecution::DeconvBufExecution(const std::vector<Tensor *> &inputs, cons
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};
bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), needTrans);
mResource->mBuildOptions.emplace("-DBIAS");
if (conv2dCommonParams->relu() == true) {


@ -39,24 +39,13 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector<Tensor *>
mResource->mFilter.reset(Tensor::createDevice<float>({1, ROUND_UP(filterImageShape[1], 2)/*for kernel C8 read*/, 1, 4 * filterImageShape[0]}));
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>(filterShape));
int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}
size_t buffer_size = filterBuffer->elementSize() * sizeof(float);
cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
filterBuffer->buffer().device = (uint64_t)(&filterBufferCL);
cl_int error;
auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
if(ptrCL != nullptr && error == CL_SUCCESS){
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
for (int i = 0; i < filterBuffer->elementSize(); i++) {
((half_float::half *)ptrCL)[i] = (half_float::half)(filterDataPtr[i]);
}
} else {
::memcpy(ptrCL, filterDataPtr, filterBuffer->size());
}
::memcpy(ptrCL, filterDataPtr, filterBuffer->size());
}else{
MNN_ERROR("Map error ptrCL == nullptr \n");
}
@ -65,10 +54,7 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector<Tensor *>
mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC);
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};
bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bool needTrans = true;
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::DW_CONV2D_FILTER, mResource->mFilter.get(), needTrans);
if (mResource->mConv2dCommonParams->relu() == true) {


@ -94,7 +94,7 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
Tensor *input = inputs[0];
Tensor *output = outputs[0];
auto runtime = ((OpenCLBackend *)backend())->getOpenCLRuntime();
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize) / 4;
std::vector<int> inputShape = tensorShapeFormat(input);
std::vector<int> outputShape = tensorShapeFormat(output);
@ -114,6 +114,14 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
inner_size *= inputs.at(0)->length(i);
}
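// GroupNorm case: fold the group count into the outer dimension, outer = batch * group, inner = (product of the remaining dims) / group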
if (group_ > 1) {
outter_size = inputs[0]->length(0) * group_;
inner_size = 1;
for (int i = 1; i < rank; i++) {
inner_size *= inputs[0]->length(i);
}
inner_size /= group_;
}
std::set<std::string> buildOptions;
if(RMSNorm){
@ -150,6 +158,111 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector<Tensor *> &inputs, c
mGWS = {static_cast<uint32_t>(local_size),
static_cast<uint32_t>(1),
static_cast<uint32_t>(inputBatch)};
} else if(inner_size == inputWidth * inputHeight * inputChannels / group_ && outter_size == inputBatch * group_){
mUnits.clear();
mUnits.resize(3);
std::vector<int> inputShape = tensorShapeFormat(inputs[0]);
int inputWH[] = {inputShape[2], inputShape[1]};
int region[] = {inputShape[0], UP_DIV(inputShape[3], 4), inputShape[1], inputShape[2]};
mInputPlain = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE));
mOpenCLBackend->onAcquireBuffer(mInputPlain.get(), Backend::DYNAMIC);
mOutputPlain = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE));
mOpenCLBackend->onAcquireBuffer(mOutputPlain.get(), Backend::DYNAMIC);
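// grouped LayerNorm runs as three passes: unpack NC4HW4 -> NCHW, normalize each (batch x group) slice on the plain buffer, then repack NCHW -> NC4HW4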
// convert nc4hw4 to nchw
{
auto &unit = mUnits[0];
unit.kernel = runtime->buildKernel("buffer_convert_buf", "nc4hw4_buffer_to_nchw_buffer", {}, inputs[0], outputs[0]);
mGWS = {(uint32_t)(UP_DIV(region[3] * region[1], 16) * 16),
(uint32_t)(UP_DIV(region[2] * region[0], 16) * 16)};
mLWS = {16, 16};
unit.globalWorkSize = {mGWS[0], mGWS[1]};
unit.localWorkSize = {mLWS[0], mLWS[1]};
int global_dim0 = region[3] * region[1];
int global_dim1 = region[2] * region[0];
//MNN_CHECK_CL_SUCCESS
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, global_dim0);
ret |= unit.kernel->get().setArg(idx++, global_dim1);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, inputWH[1]);
ret |= unit.kernel->get().setArg(idx++, inputWH[0]);
ret |= unit.kernel->get().setArg(idx++, inputShape[3]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, convert nc4hw4 to nchw");
mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS);
}
// do group layernorm
{
auto &unit = mUnits[1];
kernelName = "layernorm_plain_buf";
local_size = getLocalSize(UP_DIV(inner_size, 4), MaxLocalSize);
buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size));
unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions);
mGWS = {static_cast<uint32_t>(local_size),
static_cast<uint32_t>(1),
static_cast<uint32_t>(outter_size)};
mLWS = {static_cast<uint32_t>(local_size), 1, 1};
unit.globalWorkSize = {mGWS[0], mGWS[1], mGWS[2]};
unit.localWorkSize = {mLWS[0], mLWS[1], mLWS[2]};
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGWS[0]);
ret |= unit.kernel->get().setArg(idx++, mGWS[1]);
ret |= unit.kernel->get().setArg(idx++, mGWS[2]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, static_cast<int32_t>(inner_size));
ret |= unit.kernel->get().setArg(idx++, static_cast<int32_t>(outter_size));
if(has_gamma_beta_){
ret |= unit.kernel->get().setArg(idx++, *mGammaBuffer.get());
ret |= unit.kernel->get().setArg(idx++, *mBetaBuffer.get());
}
ret |= unit.kernel->get().setArg(idx++, epsilon_);
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, do group layernorm");
mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS);
}
// convert nchw to nc4hw4
{
auto &unit = mUnits[2];
unit.kernel = runtime->buildKernel("buffer_convert_buf", "nchw_buffer_to_nc4hw4_buffer", {}, inputs[0], outputs[0]);
mLWS = {16, 16};
mGWS = {(uint32_t)UP_DIV(region[3] * region[1], 16) * 16,
(uint32_t)UP_DIV(region[2] * region[0], 16) * 16};
unit.globalWorkSize = {mGWS[0], mGWS[1]};
unit.localWorkSize = {mLWS[0], mLWS[1]};
int global_dim0 = region[3] * region[1];
int global_dim1 = region[2] * region[0];
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, global_dim0);
ret |= unit.kernel->get().setArg(idx++, global_dim1);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get()));
ret |= unit.kernel->get().setArg(idx++, inputWH[1]);
ret |= unit.kernel->get().setArg(idx++, inputWH[0]);
ret |= unit.kernel->get().setArg(idx++, inputShape[3]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output));
MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, convert nchw to nc4hw4");
mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS);
}
mOpenCLBackend->onReleaseBuffer(mInputPlain.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mOutputPlain.get(), Backend::DYNAMIC);
return NO_ERROR;
}
mLWS = {static_cast<uint32_t>(local_size), 1, 1};
@ -189,10 +302,6 @@ public:
TensorUtils::setTensorSupportPack(outputs[i], false);
}
const auto* layer_norm_param = op->main_as_LayerNorm();
int group = layer_norm_param->group();
if(group > 1){
return nullptr;
}
return new LayerNormBufExecution(inputs, op, backend);
}
};


@ -35,6 +35,7 @@ private:
std::shared_ptr<cl::Buffer> mGammaBuffer;
std::shared_ptr<cl::Buffer> mBetaBuffer;
std::shared_ptr<Tensor> mInputPlain, mOutputPlain;
bool has_gamma_beta_ = false;
uint32_t mMaxWorkGroupSize;
};


@ -208,7 +208,8 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector<Tensor *> &inp
const int Width = Shape.at(2);
const int Height = Shape.at(1);
const int Batch = Shape.at(0);
mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, ROUND_UP(Height, 4), ROUND_UP(Width, 4)}, Tensor::CAFFE));
mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE));
mOpenCLBackend->onAcquireBuffer(mTmpTensors[i].get(), Backend::DYNAMIC);
Unit unit;


@ -54,7 +54,9 @@ ErrorCode SoftmaxBufExecution::onEncode(const std::vector<Tensor *> &inputs, con
const auto dims = input->buffer().dimensions;
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize) / 4;
int inside = 1;
int outside = 1;
int channel = 1;

Some files were not shown because too many files have changed in this diff.