[MNN:Sync] Sync Internal 2.8.1

This commit is contained in:
zhaode.wzd 2023-12-27 17:26:44 +08:00
parent 1a5609b861
commit 3b978d9d16
282 changed files with 9445 additions and 4681 deletions

View File

@ -489,6 +489,7 @@ IF(MNN_COREML)
IF(MNN_SEP_BUILD) IF(MNN_SEP_BUILD)
list(APPEND MNN_DEPS MNNCoreML) list(APPEND MNN_DEPS MNNCoreML)
list(APPEND MNN_EXTRA_DEPENDS MNNCoreML)
ELSE() ELSE()
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNCoreML>) list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNCoreML>)
ENDIF() ENDIF()
@ -552,6 +553,7 @@ IF(MNN_OPENCL)
IF(MNN_SEP_BUILD) IF(MNN_SEP_BUILD)
list(APPEND MNN_DEPS MNN_CL) list(APPEND MNN_DEPS MNN_CL)
ELSE() ELSE()
add_definitions(-DMNN_OPENCL_ENABLED=1)
list(APPEND MNN_TARGETS MNN_CL) list(APPEND MNN_TARGETS MNN_CL)
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_CL>) list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_CL>)
list(APPEND MNN_EXTRA_DEPENDS ${MNN_OCL_LIBS}) list(APPEND MNN_EXTRA_DEPENDS ${MNN_OCL_LIBS})

MNN_Render.podspec Normal file
View File

@ -0,0 +1,82 @@
Pod::Spec.new do |s|
s.name = "MNN"
s.version = "2.2.0"
s.summary = "MNN"
s.description = <<-DESC
MNN is a lightweight deep neural network inference framework. It loads models and does inference on devices.
DESC
s.homepage = "https://github.com/alibaba/MNN"
s.license = {
:type => 'Apache License, Version 2.0',
:text => <<-LICENSE
Copyright © 2018, Alibaba Group Holding Limited
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
LICENSE
}
s.author = { "MNN" => "MNN@alibaba-inc.com" }
s.platform = :ios
s.ios.deployment_target = '8.0'
s.requires_arc = true
#s.source = { :git => "git@github.com:alibaba/MNN.git", :branch => 'master' }
s.source = {:git => "/Users/zhang/Development/AliNNPrivate/",:branch=> 'head'}
s.frameworks = 'Metal', 'Accelerate', 'CoreML'
s.library = 'c++'
s.source_files = \
'include/MNN/*.{h,hpp}',\
'include/MNN/expr/*.{h,hpp}',\
'schema/current/*.{h}',\
'3rd_party/flatbuffers/include/flatbuffers/*.{h}',\
'source/internal/logging/*.{hpp,cpp}',\
'source/internal/logging/ios/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/internal/logging/aliyun-log-c-sdk/src/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/core/**/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/common/**/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/utils/**/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/geometry/**/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/cv/**/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/math/**/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'source/shape/*.{h,c,m,mm,cc,hpp,cpp}',\
'source/shape/render/*.{h,c,m,mm,cc,hpp,cpp}',\
#'source/backend/arm82/*.{h,c,m,mm,cc,S,hpp,cpp}',\
#'source/backend/arm82/asm/**/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/cpu/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/cpu/render/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/cpu/bf16/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/cpu/arm/**/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/cpu/compute/*.{h,c,m,mm,cc,S,hpp,cpp}',\
'source/backend/metal/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'source/backend/metal/render/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'source/backend/coreml/backend/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'source/backend/coreml/execution/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'source/backend/coreml/mlmodel/src/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'express/**/*.{hpp,cpp}',\
'tools/cv/include/**/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'tools/cv/source/imgproc/*.{h,c,m,mm,cc,hpp,cpp,metal}',\
'tools/cv/source/calib3d/*.{h,c,m,mm,cc,hpp,cpp,metal}'
s.header_mappings_dir = 'include'
s.subspec 'cv' do |sp|
sp.source_files = 'tools/cv/include/**/*.hpp'
sp.header_mappings_dir = 'tools/cv/include'
sp.xcconfig = { 'ALWAYS_SEARCH_USER_PATHS' => 'NO' }
end
s.compiler_flags = '-arch arm64 -march=armv8.2-a+simd+fp16'
s.pod_target_xcconfig = {'METAL_LIBRARY_FILE_BASE' => 'mnn', 'HEADER_SEARCH_PATHS' => '"$(PODS_TARGET_SRCROOT)/include" "$(PODS_TARGET_SRCROOT)/3rd_party/flatbuffers/include" "$(PODS_TARGET_SRCROOT)/source" "$(PODS_TARGET_SRCROOT)/3rd_party/half" "$(PODS_TARGET_SRCROOT)/source/backend/coreml/mlmodel/include" "$(PODS_TARGET_SRCROOT)/tools/cv/include"', 'GCC_PREPROCESSOR_DEFINITIONS' => '$(inherited) MNN_CODEGEN_REGISTER=1 MNN_SUPPORT_TFLITE_QUAN=1 MNN_METAL_ENABLED=1 MNN_METAL_FULL_PRECISION=1 MNN_SUPPORT_RENDER=1 MNN_SUPPORT_BF16=1 MNN_COREML_ENABLED=1 USE_LZ4_FLAG=1 MNN_INTERNAL_ENABLED=1 MNN_USE_SPARSE_COMPUTE=1'}
s.user_target_xcconfig = { 'OTHER_LDFLAGS' => '-force_load $(BUILD_DIR)/$(CONFIGURATION)$(EFFECTIVE_PLATFORM_NAME)/MNN/libMNN.a', 'HEADER_SEARCH_PATHS' => '"$(PODS_TARGET_SRCROOT)/include"' }
end

View File

@ -55,14 +55,10 @@
- `checkInvalidValue.out` Check the data in the output directory - `checkInvalidValue.out` Check the data in the output directory
- `timeProfile.out` Measure the model's execution time on the specified backend and the per-layer share of that time - `timeProfile.out` Measure the model's execution time on the specified backend and the per-layer share of that time
- `testTrain.out` Test the training functionality - `testTrain.out` Test the training functionality
- `aoa_nlu_encoder.out` Test NLU encoding
- `aoa_nlu_decoder1.out` Test NLU decoder 1
- `aoa_nlu_decoder2.out` Test NLU decoder 2
- `checkDir.out` Check whether two folders are identical - `checkDir.out` Check whether two folders are identical
- `checkFile.out` Check whether two files are identical - `checkFile.out` Check whether two files are identical
- `winogradExample.out` Winograd example - `winogradExample.out` Winograd example
- `winogradGenerateGLSL.out` Generate GLSL for Winograd - `fuseTest` Test custom GPU operators; currently only the Vulkan Buffer mode is supported
- `winogradGenerateCL.out` Generate OpenCL for Winograd
## Benchmark Tool ## Benchmark Tool
- Related build options - Related build options
- `MNN_BUILD_BENCHMARK` Whether to build the benchmark tool - `MNN_BUILD_BENCHMARK` Whether to build the benchmark tool

View File

@ -2195,6 +2195,25 @@ array([[[[0., 1.]],
[[6., 7.]]]], dtype=float32) [[6., 7.]]]], dtype=float32)
``` ```
---
### `reverse(x, axis)`
Reverse the input variable x along the dimension given by axis[0]
Parameters:
- `x : var_like` Input variable
- `axis : var_like` Axis variable specifying the dimension to reverse
Returns: the reversed values
Return type: `Var`
Example:
```python
>>> expr.reverse(expr.range(-4., 4., 1.), [0])
array([ 3., 2., 1., 0., -1., -2., -3., -4.], dtype=float32)
```
--- ---
### `reverse_sequence(x, y, batch_dim, seq_dim)` ### `reverse_sequence(x, y, batch_dim, seq_dim)`
Slice x along dimension batch_dim and reverse the first y[i] elements along dimension seq_dim Slice x along dimension batch_dim and reverse the first y[i] elements along dimension seq_dim

View File

@ -457,3 +457,14 @@ Matrix:
0.0000000 0.0000000 1.0000000 0.0000000 0.0000000 1.0000000
``` ```
## fuseTest
### Function
Tests custom GPU operators; currently only the Vulkan Buffer mode is supported
### Parameters
`Usage: ./fuseTest user.spirv config.json`
- `user.spirv:str`: path to the SPIR-V file; it can be generated with `glslangValidator -V user.comp -o user.spirv`
- `config.json:str`: path to the configuration file
### Example
```bash
$ ./fuseTest user.spirv user.json
```

View File

@ -120,7 +120,7 @@ Executor::Requirement Executor::getRequirement(Expr* expr) const {
return req; return req;
} }
for (int i = 0; i < inputSize; ++i) { for (int i = 0; i < inputSize; ++i) {
req.contentNeedContent[i] = OpCommonUtils::opNeedContent(op->type(), i); req.contentNeedContent[i] = OpCommonUtils::opNeedContent(op, i);
req.shapeNeedContent[i] = false; req.shapeNeedContent[i] = false;
} }
auto needIndexId = SizeComputer::needInputContent(op, inputSize); auto needIndexId = SizeComputer::needInputContent(op, inputSize);

View File

@ -192,6 +192,17 @@ EXPRP Expr::create(std::shared_ptr<BufferStorage> extra, std::vector<VARP>&& inp
EXPRP expr(new Expr(outputSize)); EXPRP expr(new Expr(outputSize));
expr->mStorage = extra; expr->mStorage = extra;
expr->mOp = flatbuffers::GetRoot<Op>(extra->buffer()); expr->mOp = flatbuffers::GetRoot<Op>(extra->buffer());
switch (expr->mOp->type()) {
case OpType_Const:
expr->mType = VARP::CONSTANT;
break;
case OpType_TrainableParam:
expr->mType = VARP::TRAINABLE;
break;
default:
expr->mType = VARP::INPUT;
break;
}
expr->mInputs = std::move(inputs); expr->mInputs = std::move(inputs);
auto exe = ExecutorScope::Current(); auto exe = ExecutorScope::Current();
expr->mInside->mReq = exe->getRequirement(expr.get()); expr->mInside->mReq = exe->getRequirement(expr.get());

View File

@ -626,6 +626,13 @@ VARP _ChannelShuffle(VARP x, int group) {
x = _Convert(x, NC4HW4); x = _Convert(x, NC4HW4);
return x; return x;
} }
VARP _Reverse(VARP x, VARP axis) {
std::unique_ptr<MNN::OpT> op(new MNN::OpT);
op->type = MNN::OpType_Reverse;
return (Variable::create(Expr::create(op.get(), {x, axis})));
}
VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim) { VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim) {
std::unique_ptr<OpT> op(new OpT); std::unique_ptr<OpT> op(new OpT);
op->type = OpType_ReverseSequence; op->type = OpType_ReverseSequence;
@ -1710,19 +1717,10 @@ VARP _GridSample(VARP input, VARP grid, InterpolationMethod mode, GridSamplePadd
} }
VARP _FloatToInt8(VARP x, VARP scale, char minValue/*For future*/, char maxValue/*For future*/) { VARP _FloatToInt8(VARP x, VARP scale, char minValue/*For future*/, char maxValue/*For future*/) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo(); auto scaleInfo = scale->getInfo();
auto scalePtr = scale->readMap<float>(); auto scalePtr = scale->readMap<float>();
if (nullptr == scalePtr || nullptr == xInfo || nullptr == scaleInfo) { if (nullptr == scalePtr || nullptr == scaleInfo) {
MNN_ERROR("Error for FloatToInt8 because var not ready\n"); MNN_ERROR("Error for FloatToInt8 because scale not ready\n");
return nullptr;
}
if (xInfo->order != NC4HW4 || xInfo->type.code != halide_type_float) {
MNN_ERROR("Not Support Input for FloatToInt8 because var not NC4HW4 or not float\n");
return nullptr;
}
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
MNN_ERROR("Scale's size not match input's channel: %d - %d\n", scaleInfo->size, xInfo->dim[1]);
return nullptr; return nullptr;
} }
std::unique_ptr<OpT> op(new OpT); std::unique_ptr<OpT> op(new OpT);
@ -1735,21 +1733,12 @@ VARP _FloatToInt8(VARP x, VARP scale, char minValue/*For future*/, char maxValue
} }
VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t zeroPoint) { VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t zeroPoint) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo(); auto scaleInfo = scale->getInfo();
auto scalePtr = scale->readMap<float>(); auto scalePtr = scale->readMap<float>();
if (nullptr == scalePtr || nullptr == xInfo || nullptr == scaleInfo) { if (nullptr == scalePtr || nullptr == scaleInfo) {
MNN_ERROR("Error for FloatToInt8 because var not ready\n"); MNN_ERROR("Error for FloatToInt8 because var not ready\n");
return nullptr; return nullptr;
} }
if (xInfo->order != NC4HW4 || xInfo->type.code != halide_type_float) {
MNN_ERROR("Not Support Input for FloatToInt8 because var not NC4HW4 or not float\n");
return nullptr;
}
if ((scaleInfo->size != xInfo->dim[1]) && (scaleInfo->size != 1)) {
MNN_ERROR("Scale's size not match input's channel: %d - %d\n", scaleInfo->size, xInfo->dim[1]);
return nullptr;
}
std::unique_ptr<OpT> op(new OpT); std::unique_ptr<OpT> op(new OpT);
op->type = OpType_FloatToInt8; op->type = OpType_FloatToInt8;
op->main.type = OpParameter_QuantizedFloatParam; op->main.type = OpParameter_QuantizedFloatParam;

View File

@ -58,6 +58,10 @@ ExprModule::ExprModule(EXPRP expr) {
break; break;
} }
} }
// TODO: Optimize the logic
if (!mExpr->mCanDecompose) {
ExecutorScope::Current()->setLazyComputeMode(Executor::LAZY_CONTENT);
}
} }
std::vector<VARP> ExprModule::onForward(const std::vector<VARP>& inputs) { std::vector<VARP> ExprModule::onForward(const std::vector<VARP>& inputs) {
@ -72,6 +76,14 @@ std::vector<VARP> ExprModule::onForward(const std::vector<VARP>& inputs) {
std::vector<VARP> outputVars; std::vector<VARP> outputVars;
auto newExpr = Expr::create(mExpr->extra(), std::move(tempInputs), mExpr->outputSize()); auto newExpr = Expr::create(mExpr->extra(), std::move(tempInputs), mExpr->outputSize());
newExpr->setName(mExpr->name()); newExpr->setName(mExpr->name());
if (!mExpr->mCanDecompose) {
// Set tensor shape from net
newExpr->mCanDecompose = false;
for (int index = 0; index < mExpr->outputSize(); ++index) {
TensorUtils::copyShape(mExpr->inside()->mOutputTensors[index], newExpr->inside()->mOutputTensors[index], true, true);
Utils::copyTensorToInfo(newExpr->inside()->mOutputInfos.data() + index, newExpr->inside()->mOutputTensors[index]);
}
}
for (int i = 0; i < mExpr->outputSize(); ++i) { for (int i = 0; i < mExpr->outputSize(); ++i) {
outputVars.emplace_back(Variable::create(newExpr, i)); outputVars.emplace_back(Variable::create(newExpr, i));
} }
@ -562,6 +574,23 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
config = &defaultConfig; config = &defaultConfig;
} }
auto subGraphs = net->subgraphs(); auto subGraphs = net->subgraphs();
if (config->dynamic) {
// TODO: Support subgraph
if (nullptr == subGraphs) {
auto varMap = MNN::Express::Variable::loadMap(buffer, length);
std::vector<MNN::Express::VARP> inputsVar(inputs.size());
for (int i=0; i<inputs.size(); ++i) {
inputsVar[i] = varMap[inputs[i]];
}
std::vector<MNN::Express::VARP> outputsVar(outputs.size());
for (int i=0; i<outputs.size(); ++i) {
outputsVar[i] = varMap[outputs[i]];
}
return extract(inputsVar, outputsVar, false);
} else {
MNN_ERROR("Don't support subgraph for dynamic load, turn back to static load\n");
}
}
std::map<std::string, SubGraph> subGraphMap; std::map<std::string, SubGraph> subGraphMap;
_createSubGraph(net, rtMgr, config, subGraphMap); _createSubGraph(net, rtMgr, config, subGraphMap);
std::shared_ptr<BufferStorage> bufferStorage(new BufferStorage); std::shared_ptr<BufferStorage> bufferStorage(new BufferStorage);
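A minimal sketch of how the dynamic-load path added above might be exercised from user code, assuming `Module::Config` exposes the `dynamic` flag read here and using the `Module::load` overload that takes a runtime manager; the model path and tensor names are placeholders.

```cpp
#include <memory>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/Module.hpp>

using namespace MNN::Express;

int main() {
    Module::Config config;
    config.shapeMutable = true;
    config.dynamic      = true;  // take the Variable::loadMap + extract path added above
    // A runtime manager is optional; left null here so the default runtime is used (assumption).
    std::shared_ptr<Executor::RuntimeManager> rtMgr;
    // "model.mnn", "input" and "output" are placeholder names.
    std::shared_ptr<Module> net(Module::load({"input"}, {"output"}, "model.mnn", rtMgr, &config));
    if (nullptr == net) {
        return 1;  // models with subgraphs fall back to static load, per the message above
    }
    return 0;
}
```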

View File

@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x) #define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2 #define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 8 #define MNN_VERSION_MINOR 8
#define MNN_VERSION_PATCH 0 #define MNN_VERSION_PATCH 1
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH) #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */ #endif /* MNNDefine_h */

View File

@ -24,6 +24,15 @@ struct MNNVulkanContext {
uint32_t iQueueFamilyIndex; uint32_t iQueueFamilyIndex;
}; };
struct MNNVulkanTensorContent {
VkBuffer buffer;
VkDeviceSize size;
VkDeviceSize offset;
halide_type_t realType;
int32_t mask; // For future usage
};
#endif #endif
#ifdef MNN_METAL #ifdef MNN_METAL
@ -36,6 +45,9 @@ struct MNNMetalTensorContent {
id<MTLBuffer> buffer; id<MTLBuffer> buffer;
int32_t offset; int32_t offset;
id<MTLTexture> texture; id<MTLTexture> texture;
halide_type_t type;
int32_t mask;
int32_t forFuture[8]; int32_t forFuture[8];
}; };

View File

@ -275,6 +275,12 @@ public:
mBuffer.dim[index].extent = length; mBuffer.dim[index].extent = length;
} }
/**
* @brief For GPU and other devices, get the underlying memory directly; see MNNSharedContext for details
* @return Success or not. If forwardType does not match the tensor's backend type, or the tensor is on CPU, return false
*/
bool getDeviceInfo(void* dst, int forwardType) const;
public: public:
/** /**
* @brief print tensor data. for DEBUG use only. * @brief print tensor data. for DEBUG use only.
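A sketch of how the new `getDeviceInfo` query could be combined with the `MNNVulkanTensorContent` structure added in MNNSharedContext.h. It assumes the tensor already lives on a Vulkan backend and that the build enables Vulkan (normally the build system defines `MNN_VULKAN`), so treat it as an illustration rather than the canonical API.

```cpp
#define MNN_VULKAN 1           // assumption: normally defined by the build system
#include <vulkan/vulkan.h>     // assumption: Vulkan headers available
#include <MNN/MNNForwardType.h>
#include <MNN/MNNSharedContext.h>
#include <MNN/Tensor.hpp>

// deviceTensor: a MNN::Tensor owned by a Vulkan-backed session or module (assumption).
static bool dumpVulkanBuffer(const MNN::Tensor* deviceTensor) {
    MNNVulkanTensorContent content;
    // Returns false for CPU tensors or when the forward type does not match the tensor's backend.
    if (!deviceTensor->getDeviceInfo(&content, MNN_FORWARD_VULKAN)) {
        return false;
    }
    // content.buffer / content.offset / content.size can now be bound to a user pipeline.
    return content.buffer != VK_NULL_HANDLE;
}
```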

View File

@ -267,6 +267,7 @@ private:
bool mVisited = false; bool mVisited = false;
std::vector<WeakEXPRP> mTo; std::vector<WeakEXPRP> mTo;
bool mCanDecompose = true; bool mCanDecompose = true;
friend class ExprModule;
}; };
} // namespace Express } // namespace Express

View File

@ -77,6 +77,7 @@ MNN_PUBLIC VARP _ChangeInputFormat(VARP input, Dimensionformat format);
MNN_PUBLIC VARP _Conv2DBackPropFilter(VARP input, VARP inputGrad, INTS kernelSize, PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); MNN_PUBLIC VARP _Conv2DBackPropFilter(VARP input, VARP inputGrad, INTS kernelSize, PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0});
MNN_PUBLIC VARP _PoolGrad(VARP originInput, VARP originOutput, VARP inputGrad, INTS kernel, INTS stride, PoolingMode type, PaddingMode pad = VALID, INTS pads= {0, 0}); MNN_PUBLIC VARP _PoolGrad(VARP originInput, VARP originOutput, VARP inputGrad, INTS kernel, INTS stride, PoolingMode type, PaddingMode pad = VALID, INTS pads= {0, 0});
// FIXME: move the api to Array Ops // FIXME: move the api to Array Ops
MNN_PUBLIC VARP _Reverse(VARP x, VARP axis);
MNN_PUBLIC VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim); MNN_PUBLIC VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim);
// FIXME: move the api to Image Ops // FIXME: move the api to Image Ops
MNN_PUBLIC VARP _Crop(VARP images, VARP size, int axis, INTS offset); MNN_PUBLIC VARP _Crop(VARP images, VARP size, int axis, INTS offset);
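A brief usage sketch of the new `_Reverse`, mirroring the Python `expr.reverse` example earlier in this commit; passing the axis as a scalar integer variable built with `_Scalar` is an assumption about the expected form of `axis`.

```cpp
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>

using namespace MNN::Express;

int main() {
    float data[8] = {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f};
    auto x    = _Const(data, {8}, NCHW, halide_type_of<float>());
    auto axis = _Scalar<int32_t>(0);     // reverse along dimension 0 (assumed axis form)
    auto y    = _Reverse(x, axis);
    auto ptr  = y->readMap<float>();     // expected: 3, 2, 1, 0, -1, -2, -3, -4
    return ptr == nullptr ? 1 : 0;
}
```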

View File

@ -1,64 +0,0 @@
//
// cli_demo.cpp
//
// Created by MNN on 2023/03/24.
// ZhaodeWang
//
#include "llm.hpp"
#include <fstream>
#include <stdlib.h>
void benchmark(Llm* llm, std::string prompt_file) {
std::cout << "prompt file is " << prompt_file << std::endl;
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
while (std::getline(prompt_fs, prompt)) {
// prompt start with '#' will be ignored
if (prompt.substr(0, 1) == "#") {
continue;
}
prompts.push_back(prompt);
}
int prompt_len = 0;
int decode_len = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
// llm->warmup();
for (int i = 0; i < prompts.size(); i++) {
llm->response(prompts[i]);
prompt_len += llm->prompt_len_;
decode_len += llm->gen_seq_len_;
prefill_time += llm->prefill_us_;
decode_time += llm->decode_us_;
llm->reset();
}
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
printf("\n#################################\n");
printf("prompt tokens num = %d\n", prompt_len);
printf("decode tokens num = %d\n", decode_len);
printf("prefill time = %.2f s\n", prefill_s);
printf(" decode time = %.2f s\n", decode_s);
printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
printf("##################################\n");
}
int main(int argc, const char* argv[]) {
if (argc < 2) {
std::cout << "Usage: " << argv[0] << " model_dir <prompt.txt>" << std::endl;
return 0;
}
std::string model_dir = argv[1];
std::cout << "model path is " << model_dir << std::endl;
std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
llm->load(model_dir);
if (argc < 3) {
llm->chat();
}
std::string prompt_file = argv[2];
benchmark(llm.get(), prompt_file);
return 0;
}

View File

@ -11,8 +11,10 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <iostream> #include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <MNN/AutoTime.hpp> #include <MNN/AutoTime.hpp>
#include <MNN/expr/Expr.hpp> #include <MNN/expr/Expr.hpp>
@ -25,6 +27,25 @@ using namespace MNN;
using namespace Express; using namespace Express;
class Tokenizer; class Tokenizer;
// llm stream buffer with callback
class LlmStreamBuffer : public std::streambuf {
public:
using CallBack = std::function<void(const char* str, size_t len)>;
LlmStreamBuffer(CallBack callback) : callback_(callback) {}
protected:
virtual std::streamsize xsputn(const char* s, std::streamsize n) override {
if (callback_) {
callback_(s, n);
}
return n;
}
private:
CallBack callback_ = nullptr;
};
class MNN_PUBLIC Llm { class MNN_PUBLIC Llm {
public: public:
Llm() { Llm() {
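A small sketch of how `LlmStreamBuffer` could be used to receive tokens as they are generated, assuming `Llm::response` accepts an output stream (as its signature later in this commit suggests) and that its remaining parameters have defaults.

```cpp
#include <cstdio>
#include <ostream>
#include <string>
#include "llm.hpp"  // this repo's LLM demo header, providing Llm and LlmStreamBuffer

void streamResponse(Llm* llm, const std::string& query) {
    // Forward every chunk the model writes straight to stdout.
    LlmStreamBuffer streambuf([](const char* str, size_t len) {
        fwrite(str, 1, len, stdout);
        fflush(stdout);
    });
    std::ostream os(&streambuf);
    llm->response(query, &os);  // assumption: trailing arguments of response() are defaulted
}
```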

View File

@ -80,9 +80,8 @@ public:
virtual std::vector<int> encode(const std::string& str) override; virtual std::vector<int> encode(const std::string& str) override;
virtual std::string decode(int id) override; virtual std::string decode(int id) override;
private: private:
std::unordered_map<std::string, int> encoder_;
std::vector<std::string> decoder_; std::vector<std::string> decoder_;
std::vector<int> tokens_;
std::vector<int> token_ids_;
}; };
#endif // TOKENIZER_hpp #endif // TOKENIZER_hpp

View File

@ -106,8 +106,7 @@ std::string Llm::response(const std::string& query, std::ostream* os, const char
history_ = input_ids; history_ = input_ids;
} }
prompt_len_ = input_ids.size(); prompt_len_ = static_cast<int>(input_ids.size());
// printf("token_num : %lu\n", input_ids.size());
auto st = std::chrono::system_clock::now(); auto st = std::chrono::system_clock::now();
int token = forward(input_ids); int token = forward(input_ids);
auto et = std::chrono::system_clock::now(); auto et = std::chrono::system_clock::now();
@ -168,20 +167,22 @@ void Llm::load(const std::string& model_dir) {
config.type = MNN_FORWARD_CPU; config.type = MNN_FORWARD_CPU;
// config.type = MNN_FORWARD_OPENCL; // config.type = MNN_FORWARD_OPENCL;
config.numThread = 4; config.numThread = 4;
// cpuBackendConfig.precision = BackendConfig::Precision_Low; cpuBackendConfig.precision = BackendConfig::Precision_Low;
cpuBackendConfig.memory = BackendConfig::Memory_Low; cpuBackendConfig.memory = BackendConfig::Memory_Low;
config.backendConfig = &cpuBackendConfig; config.backendConfig = &cpuBackendConfig;
runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config)); runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
if (config.type == MNN_FORWARD_OPENCL) { if (config.type == MNN_FORWARD_OPENCL) {
const char* cacheFileName = ".tempcache"; const char* cacheFileName = ".tempcache";
runtime_manager_->setCache(cacheFileName); // runtime_manager_->setCache(cacheFileName);
} }
load_progress_ = 0.f; load_progress_ = 0.f;
printf("load tokenizer\n");
// 1. load vocab // 1. load vocab
std::string tokenizer_path = model_dir + "/tokenizer.txt"; std::string tokenizer_path = model_dir + "/tokenizer.txt";
load_progress_ += 5.f; load_progress_ += 5.f;
tokenizer_->load(tokenizer_path); tokenizer_->load(tokenizer_path);
load_progress_ += 5.f; load_progress_ += 5.f;
printf("load tokenizer Done\n");
// 2. load model // 2. load model
Module::Config module_config; Module::Config module_config;
module_config.shapeMutable = true; module_config.shapeMutable = true;
@ -228,7 +229,7 @@ void Llm::load(const std::string& model_dir) {
} }
} }
if (config.type == MNN_FORWARD_OPENCL) { if (config.type == MNN_FORWARD_OPENCL) {
warmup(); // warmup();
} }
} }
@ -369,8 +370,10 @@ bool Chatglm_6b::is_stop(int token_id) {
std::vector<int> Chatglm2_6b::tokenizer(const std::string& query) { std::vector<int> Chatglm2_6b::tokenizer(const std::string& query) {
auto prompt = "问:" + query + "\n答:"; auto prompt = "问:" + query + "\n答:";
auto ids = tokenizer_encode(prompt); auto ids = tokenizer_encode(prompt);
ids.insert(ids.begin(), 64792); if (history_.empty()) {
ids.insert(ids.begin(), 64790); ids.insert(ids.begin(), 64792);
ids.insert(ids.begin(), 64790);
}
return ids; return ids;
} }

View File

@ -307,81 +307,48 @@ const int CHARACTER_VOCABULARY_SIZE = 256;
bool Tiktoken::load(const std::string& filename) { bool Tiktoken::load(const std::string& filename) {
std::ifstream tok_file(filename); std::ifstream tok_file(filename);
int index = -1, start = 0, rough_count = 0;
std::string token; std::string token;
while (tok_file >> token) { while (tok_file >> token) {
token = base64_decode(token); token = base64_decode(token);
encoder_[token] = static_cast<int>(decoder_.size());
decoder_.push_back(token); decoder_.push_back(token);
rough_count += token.size();
} }
tok_file.close(); tok_file.close();
tokens_.resize(rough_count * CHARACTER_VOCABULARY_SIZE, -1);
token_ids_.resize(rough_count * CHARACTER_VOCABULARY_SIZE, -1);
for (int n = 0; n < decoder_.size(); n++) {
token = decoder_[n];
int root = 0;
for (int i = 0; i < token.size(); i++) {
unsigned char x = token[i];
// record the token id at the parent of leaf node
if (i == token.size() - 1) {
token_ids_[root + x] = n;
}
// trace down a tree node.
// insert a subtree when needed.
if (tokens_[root + x] == -1) {
start += CHARACTER_VOCABULARY_SIZE;
tokens_[root + x] = start;
root = start;
} else {
root = tokens_[root + x];
}
}
}
tokens_.resize(start + CHARACTER_VOCABULARY_SIZE);
token_ids_.resize(start + CHARACTER_VOCABULARY_SIZE);
tokens_.shrink_to_fit();
token_ids_.shrink_to_fit();
return true; return true;
} }
// ref: https://github.com/youkaichao/fast_bpe_tokenizer
std::vector<int> Tiktoken::encode(const std::string& str) { std::vector<int> Tiktoken::encode(const std::string& str) {
std::vector<int> ids; std::vector<int> ids;
if (str.empty()) { if (str.empty()) {
return ids; return ids;
} }
int i = 0; size_t i = 0;
int root = 0;
int root_token_id = -1;
int last_found_position = -1;
int last_found_token_id = -1;
while (i < str.size()) { while (i < str.size()) {
unsigned char x = str[i]; bool found_pair = false;
bool should_fall_back = false; // Attempt to match the longest possible symbol
if (tokens_[root + x] != -1) { size_t longest_match_len = 0;
root_token_id = token_ids_[root + x]; std::string longest_match;
root = tokens_[root + x];
if (root_token_id != -1) { // Check substrings of decreasing length
// a token ends at position i for (size_t len = str.size() - i; len > 0; --len) {
last_found_position = i; std::string token = str.substr(i, len);
last_found_token_id = root_token_id; auto it = encoder_.find(token);
if (it != encoder_.end()) {
if (len > longest_match_len) {
longest_match_len = len;
longest_match = it->first;
}
} }
i++;
if (i == str.size()) {
should_fall_back = true;
}
} else {
// assert(last_found_position != -1);
should_fall_back = true;
} }
if (should_fall_back) {
i = last_found_position + 1; if (!longest_match.empty()) {
ids.push_back(last_found_token_id); ids.push_back(encoder_.at(longest_match));
// start searching from the root again i += longest_match_len;
root = 0; } else {
root_token_id = -1; // If no matching symbol is found, this typically means an error in the encoding
last_found_position = -1; // or the input text contains characters that the encoder doesn't know how to handle
last_found_token_id = -1; std::cerr << "Error: No encoding found for the sequence starting at position " << i << std::endl;
return {};
} }
} }
return ids; return ids;
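For clarity, here is a self-contained illustration (not MNN code) of the greedy longest-match strategy the rewritten `Tiktoken::encode` applies: at each position, the longest vocabulary entry that matches is emitted and the cursor advances by its length. The toy vocabulary and ids are made up for the example.

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Greedy longest-match tokenization over an explicit string->id vocabulary.
std::vector<int> greedy_encode(const std::string& str,
                               const std::unordered_map<std::string, int>& encoder) {
    std::vector<int> ids;
    size_t i = 0;
    while (i < str.size()) {
        size_t best_len = 0;
        int best_id = -1;
        // Try substrings of decreasing length; the first hit is the longest match.
        for (size_t len = str.size() - i; len > 0; --len) {
            auto it = encoder.find(str.substr(i, len));
            if (it != encoder.end()) { best_len = len; best_id = it->second; break; }
        }
        if (best_len == 0) return {};  // no vocabulary entry covers this position
        ids.push_back(best_id);
        i += best_len;
    }
    return ids;
}

int main() {
    std::unordered_map<std::string, int> vocab{{"he", 0}, {"hello", 1}, {"llo", 2}, {" world", 3}};
    for (int id : greedy_encode("hello world", vocab)) printf("%d ", id);  // prints: 1 3
    printf("\n");
    return 0;
}
```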

View File

@ -8,6 +8,7 @@ adb push ./libMNN_Vulkan.so /data/local/tmp/$DIR/libMNN_Vulkan.so
adb push ./libMNN_GL.so /data/local/tmp/$DIR/libMNN_GL.so adb push ./libMNN_GL.so /data/local/tmp/$DIR/libMNN_GL.so
adb push ./libMNN_Express.so /data/local/tmp/$DIR/libMNN_Express.so adb push ./libMNN_Express.so /data/local/tmp/$DIR/libMNN_Express.so
adb push ./MNNV2Basic.out /data/local/tmp/$DIR/MNNV2Basic.out adb push ./MNNV2Basic.out /data/local/tmp/$DIR/MNNV2Basic.out
adb push ./ModuleBasic.out /data/local/tmp/$DIR/ModuleBasic.out
adb shell "cd /data/local/tmp/$DIR && rm -r output" adb shell "cd /data/local/tmp/$DIR && rm -r output"
adb shell "cd /data/local/tmp/$DIR && mkdir output" adb shell "cd /data/local/tmp/$DIR && mkdir output"
adb push ./unitTest.out /data/local/tmp/$DIR/unitTest.out adb push ./unitTest.out /data/local/tmp/$DIR/unitTest.out

View File

@ -163,14 +163,12 @@
4896D37E25FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S in Sources */ = {isa = PBXBuildFile; fileRef = 4896D37625FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; }; 4896D37E25FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S in Sources */ = {isa = PBXBuildFile; fileRef = 4896D37625FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
4896D37F25FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 4896D37725FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; }; 4896D37F25FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 4896D37725FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A172550FDC800AD896A /* MetalReduction.hpp */; }; 489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A172550FDC800AD896A /* MetalReduction.hpp */; };
489D7A6A2550FDC800AD896A /* MetalConvolutionGEMM.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A192550FDC800AD896A /* MetalConvolutionGEMM.hpp */; };
489D7A6E2550FDC800AD896A /* MetalROIPooling.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */; }; 489D7A6E2550FDC800AD896A /* MetalROIPooling.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */; };
489D7A6F2550FDC800AD896A /* MetalCast.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A1E2550FDC800AD896A /* MetalCast.mm */; }; 489D7A6F2550FDC800AD896A /* MetalCast.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A1E2550FDC800AD896A /* MetalCast.mm */; };
489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A1F2550FDC800AD896A /* MetalRaster.hpp */; }; 489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A1F2550FDC800AD896A /* MetalRaster.hpp */; };
489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A212550FDC800AD896A /* MetalReLU6.hpp */; }; 489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A212550FDC800AD896A /* MetalReLU6.hpp */; };
489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A222550FDC800AD896A /* MetalBackend.hpp */; }; 489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A222550FDC800AD896A /* MetalBackend.hpp */; };
489D7A762550FDC800AD896A /* MetalReduction.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A252550FDC800AD896A /* MetalReduction.mm */; }; 489D7A762550FDC800AD896A /* MetalReduction.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A252550FDC800AD896A /* MetalReduction.mm */; };
489D7A772550FDC800AD896A /* MetalConvolutionGEMM.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A262550FDC800AD896A /* MetalConvolutionGEMM.mm */; };
489D7A782550FDC800AD896A /* MetalEltwise.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A272550FDC800AD896A /* MetalEltwise.mm */; }; 489D7A782550FDC800AD896A /* MetalEltwise.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A272550FDC800AD896A /* MetalEltwise.mm */; };
489D7A792550FDC800AD896A /* MetalConvolution1x1.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */; }; 489D7A792550FDC800AD896A /* MetalConvolution1x1.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */; };
489D7A7B2550FDC800AD896A /* MetalUnary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */; }; 489D7A7B2550FDC800AD896A /* MetalUnary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */; };
@ -206,7 +204,6 @@
489D7AA72550FDC900AD896A /* MetalScale.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A562550FDC800AD896A /* MetalScale.hpp */; }; 489D7AA72550FDC900AD896A /* MetalScale.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A562550FDC800AD896A /* MetalScale.hpp */; };
489D7AA82550FDC900AD896A /* MetalCast.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A572550FDC800AD896A /* MetalCast.hpp */; }; 489D7AA82550FDC900AD896A /* MetalCast.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A572550FDC800AD896A /* MetalCast.hpp */; };
489D7AAF2550FDC900AD896A /* MetalConvolutionWinograd.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */; }; 489D7AAF2550FDC900AD896A /* MetalConvolutionWinograd.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */; };
489D7AB02550FDC900AD896A /* MetalDefine.h in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A5F2550FDC800AD896A /* MetalDefine.h */; };
489D7AB32550FDC900AD896A /* MetalPReLU.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A622550FDC800AD896A /* MetalPReLU.mm */; }; 489D7AB32550FDC900AD896A /* MetalPReLU.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A622550FDC800AD896A /* MetalPReLU.mm */; };
489D7AB42550FDC900AD896A /* MetalBinary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A632550FDC800AD896A /* MetalBinary.hpp */; }; 489D7AB42550FDC900AD896A /* MetalBinary.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 489D7A632550FDC800AD896A /* MetalBinary.hpp */; };
489D7AB62550FDC900AD896A /* MetalReLU6.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A652550FDC800AD896A /* MetalReLU6.mm */; }; 489D7AB62550FDC900AD896A /* MetalReLU6.mm in Sources */ = {isa = PBXBuildFile; fileRef = 489D7A652550FDC800AD896A /* MetalReLU6.mm */; };
@ -283,7 +280,6 @@
4AF4FB2D269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */; }; 4AF4FB2D269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */; };
4AF4FB2E269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */; }; 4AF4FB2E269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */; };
4D0C80E32862FC4100C7CAD6 /* CoreMLOPRegister.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */; }; 4D0C80E32862FC4100C7CAD6 /* CoreMLOPRegister.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */; };
4D0C80E52862FC4700C7CAD6 /* CoreMLRaster.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4D0C80E42862FC4700C7CAD6 /* CoreMLRaster.metal */; };
4D4CF4672760946500A36D9F /* miscellaneous.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4622760946500A36D9F /* miscellaneous.cpp */; }; 4D4CF4672760946500A36D9F /* miscellaneous.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4622760946500A36D9F /* miscellaneous.cpp */; };
4D4CF4682760946500A36D9F /* geometric.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4632760946500A36D9F /* geometric.cpp */; }; 4D4CF4682760946500A36D9F /* geometric.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4632760946500A36D9F /* geometric.cpp */; };
4D4CF4692760946500A36D9F /* filter.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4642760946500A36D9F /* filter.cpp */; }; 4D4CF4692760946500A36D9F /* filter.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4D4CF4642760946500A36D9F /* filter.cpp */; };
@ -771,8 +767,11 @@
CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; }; CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; };
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; }; CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; };
CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */; }; CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */; };
CE8049AC2B31C65B009B422C /* CPULayerNorm.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE8049A92B31C65B009B422C /* CPULayerNorm.hpp */; };
CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */; }; CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */; };
CE9AFED728E54E3300566949 /* CPUInterp3D.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */; }; CE9AFED728E54E3300566949 /* CPUInterp3D.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */; };
CEA49AA82AFD010900971CB7 /* MetalExecution.mm in Sources */ = {isa = PBXBuildFile; fileRef = CEA49AA62AFD010900971CB7 /* MetalExecution.mm */; };
CEA49AA92AFD010900971CB7 /* MetalExecution.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CEA49AA72AFD010900971CB7 /* MetalExecution.hpp */; };
CEA82BDB2A15F8AD002CBC95 /* IdstConvolutionInt8.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */; }; CEA82BDB2A15F8AD002CBC95 /* IdstConvolutionInt8.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */; };
CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */; }; CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */; };
CEDB20EB2846D07100AE9DC4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = CEDB20EA2846D07100AE9DC4 /* AppDelegate.m */; }; CEDB20EB2846D07100AE9DC4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = CEDB20EA2846D07100AE9DC4 /* AppDelegate.m */; };
@ -984,14 +983,12 @@
4896D37625FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = Arm82MNNPackForMatMul_A.S; path = ../../../arm82/asm/arm64/Arm82MNNPackForMatMul_A.S; sourceTree = "<group>"; }; 4896D37625FE2A6B00717702 /* Arm82MNNPackForMatMul_A.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = Arm82MNNPackForMatMul_A.S; path = ../../../arm82/asm/arm64/Arm82MNNPackForMatMul_A.S; sourceTree = "<group>"; };
4896D37725FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNConvRunForLineDepthwiseFP16.S; path = ../../../arm82/asm/arm64/MNNConvRunForLineDepthwiseFP16.S; sourceTree = "<group>"; }; 4896D37725FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNConvRunForLineDepthwiseFP16.S; path = ../../../arm82/asm/arm64/MNNConvRunForLineDepthwiseFP16.S; sourceTree = "<group>"; };
489D7A172550FDC800AD896A /* MetalReduction.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReduction.hpp; sourceTree = "<group>"; }; 489D7A172550FDC800AD896A /* MetalReduction.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReduction.hpp; sourceTree = "<group>"; };
489D7A192550FDC800AD896A /* MetalConvolutionGEMM.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalConvolutionGEMM.hpp; sourceTree = "<group>"; };
489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalROIPooling.hpp; sourceTree = "<group>"; }; 489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalROIPooling.hpp; sourceTree = "<group>"; };
489D7A1E2550FDC800AD896A /* MetalCast.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalCast.mm; sourceTree = "<group>"; }; 489D7A1E2550FDC800AD896A /* MetalCast.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalCast.mm; sourceTree = "<group>"; };
489D7A1F2550FDC800AD896A /* MetalRaster.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalRaster.hpp; sourceTree = "<group>"; }; 489D7A1F2550FDC800AD896A /* MetalRaster.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalRaster.hpp; sourceTree = "<group>"; };
489D7A212550FDC800AD896A /* MetalReLU6.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReLU6.hpp; sourceTree = "<group>"; }; 489D7A212550FDC800AD896A /* MetalReLU6.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalReLU6.hpp; sourceTree = "<group>"; };
489D7A222550FDC800AD896A /* MetalBackend.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalBackend.hpp; sourceTree = "<group>"; }; 489D7A222550FDC800AD896A /* MetalBackend.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalBackend.hpp; sourceTree = "<group>"; };
489D7A252550FDC800AD896A /* MetalReduction.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReduction.mm; sourceTree = "<group>"; }; 489D7A252550FDC800AD896A /* MetalReduction.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReduction.mm; sourceTree = "<group>"; };
489D7A262550FDC800AD896A /* MetalConvolutionGEMM.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolutionGEMM.mm; sourceTree = "<group>"; };
489D7A272550FDC800AD896A /* MetalEltwise.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalEltwise.mm; sourceTree = "<group>"; }; 489D7A272550FDC800AD896A /* MetalEltwise.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalEltwise.mm; sourceTree = "<group>"; };
489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolution1x1.mm; sourceTree = "<group>"; }; 489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolution1x1.mm; sourceTree = "<group>"; };
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalUnary.hpp; sourceTree = "<group>"; }; 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalUnary.hpp; sourceTree = "<group>"; };
@ -1027,7 +1024,6 @@
489D7A562550FDC800AD896A /* MetalScale.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalScale.hpp; sourceTree = "<group>"; }; 489D7A562550FDC800AD896A /* MetalScale.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalScale.hpp; sourceTree = "<group>"; };
489D7A572550FDC800AD896A /* MetalCast.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalCast.hpp; sourceTree = "<group>"; }; 489D7A572550FDC800AD896A /* MetalCast.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalCast.hpp; sourceTree = "<group>"; };
489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolutionWinograd.mm; sourceTree = "<group>"; }; 489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalConvolutionWinograd.mm; sourceTree = "<group>"; };
489D7A5F2550FDC800AD896A /* MetalDefine.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = MetalDefine.h; sourceTree = "<group>"; };
489D7A622550FDC800AD896A /* MetalPReLU.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalPReLU.mm; sourceTree = "<group>"; }; 489D7A622550FDC800AD896A /* MetalPReLU.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalPReLU.mm; sourceTree = "<group>"; };
489D7A632550FDC800AD896A /* MetalBinary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalBinary.hpp; sourceTree = "<group>"; }; 489D7A632550FDC800AD896A /* MetalBinary.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalBinary.hpp; sourceTree = "<group>"; };
489D7A652550FDC800AD896A /* MetalReLU6.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReLU6.mm; sourceTree = "<group>"; }; 489D7A652550FDC800AD896A /* MetalReLU6.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalReLU6.mm; sourceTree = "<group>"; };
@ -1104,7 +1100,6 @@
4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx1.S; sourceTree = "<group>"; }; 4AF4FB2B269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx1.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx1.S; sourceTree = "<group>"; };
4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx4.S; sourceTree = "<group>"; }; 4AF4FB2C269ED24C005BA97B /* MNNPackedSparseQuantMatMulEpx4.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackedSparseQuantMatMulEpx4.S; sourceTree = "<group>"; };
4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CoreMLOPRegister.cpp; sourceTree = "<group>"; }; 4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CoreMLOPRegister.cpp; sourceTree = "<group>"; };
4D0C80E42862FC4700C7CAD6 /* CoreMLRaster.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = CoreMLRaster.metal; sourceTree = "<group>"; };
4D4CF4622760946500A36D9F /* miscellaneous.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = miscellaneous.cpp; sourceTree = "<group>"; }; 4D4CF4622760946500A36D9F /* miscellaneous.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = miscellaneous.cpp; sourceTree = "<group>"; };
4D4CF4632760946500A36D9F /* geometric.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = geometric.cpp; sourceTree = "<group>"; }; 4D4CF4632760946500A36D9F /* geometric.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = geometric.cpp; sourceTree = "<group>"; };
4D4CF4642760946500A36D9F /* filter.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = filter.cpp; sourceTree = "<group>"; }; 4D4CF4642760946500A36D9F /* filter.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = filter.cpp; sourceTree = "<group>"; };
@ -1603,8 +1598,11 @@
CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = "<group>"; }; CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = "<group>"; };
CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; }; CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; };
CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ShapeConvTranspose3D.cpp; sourceTree = "<group>"; }; CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ShapeConvTranspose3D.cpp; sourceTree = "<group>"; };
CE8049A92B31C65B009B422C /* CPULayerNorm.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPULayerNorm.hpp; sourceTree = "<group>"; };
CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUInterp3D.cpp; sourceTree = "<group>"; }; CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUInterp3D.cpp; sourceTree = "<group>"; };
CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUInterp3D.hpp; sourceTree = "<group>"; }; CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUInterp3D.hpp; sourceTree = "<group>"; };
CEA49AA62AFD010900971CB7 /* MetalExecution.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalExecution.mm; sourceTree = "<group>"; };
CEA49AA72AFD010900971CB7 /* MetalExecution.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = MetalExecution.hpp; sourceTree = "<group>"; };
CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = IdstConvolutionInt8.cpp; sourceTree = "<group>"; }; CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = IdstConvolutionInt8.cpp; sourceTree = "<group>"; };
CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = IdstConvolutionInt8.hpp; sourceTree = "<group>"; }; CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = IdstConvolutionInt8.hpp; sourceTree = "<group>"; };
CEDB20E72846D07100AE9DC4 /* demo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = demo.app; sourceTree = BUILT_PRODUCTS_DIR; }; CEDB20E72846D07100AE9DC4 /* demo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = demo.app; sourceTree = BUILT_PRODUCTS_DIR; };
@ -1913,6 +1911,7 @@
48887410215B639D0079B12E /* cpu */ = { 48887410215B639D0079B12E /* cpu */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
CE8049A92B31C65B009B422C /* CPULayerNorm.hpp */,
958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */, 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */,
CEE9B95F2A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp */, CEE9B95F2A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp */,
CEE9B95E2A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp */, CEE9B95E2A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp */,
@ -2096,6 +2095,8 @@
489D7A152550FDC800AD896A /* metal */ = { 489D7A152550FDC800AD896A /* metal */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
CEA49AA72AFD010900971CB7 /* MetalExecution.hpp */,
CEA49AA62AFD010900971CB7 /* MetalExecution.mm */,
4D566298299341270031C1A1 /* MetalFuse.hpp */, 4D566298299341270031C1A1 /* MetalFuse.hpp */,
4D566299299341270031C1A1 /* MetalFuse.mm */, 4D566299299341270031C1A1 /* MetalFuse.mm */,
19D0FE73285C66F200B74B1A /* MetalLayerNorm.hpp */, 19D0FE73285C66F200B74B1A /* MetalLayerNorm.hpp */,
@ -2110,14 +2111,12 @@
4838EA802611C00B0027232C /* MetalGridSample.hpp */, 4838EA802611C00B0027232C /* MetalGridSample.hpp */,
4838EA822611C00B0027232C /* MetalGridSample.mm */, 4838EA822611C00B0027232C /* MetalGridSample.mm */,
489D7A172550FDC800AD896A /* MetalReduction.hpp */, 489D7A172550FDC800AD896A /* MetalReduction.hpp */,
489D7A192550FDC800AD896A /* MetalConvolutionGEMM.hpp */,
489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */, 489D7A1D2550FDC800AD896A /* MetalROIPooling.hpp */,
489D7A1E2550FDC800AD896A /* MetalCast.mm */, 489D7A1E2550FDC800AD896A /* MetalCast.mm */,
489D7A1F2550FDC800AD896A /* MetalRaster.hpp */, 489D7A1F2550FDC800AD896A /* MetalRaster.hpp */,
489D7A212550FDC800AD896A /* MetalReLU6.hpp */, 489D7A212550FDC800AD896A /* MetalReLU6.hpp */,
489D7A222550FDC800AD896A /* MetalBackend.hpp */, 489D7A222550FDC800AD896A /* MetalBackend.hpp */,
489D7A252550FDC800AD896A /* MetalReduction.mm */, 489D7A252550FDC800AD896A /* MetalReduction.mm */,
489D7A262550FDC800AD896A /* MetalConvolutionGEMM.mm */,
489D7A272550FDC800AD896A /* MetalEltwise.mm */, 489D7A272550FDC800AD896A /* MetalEltwise.mm */,
489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */, 489D7A282550FDC800AD896A /* MetalConvolution1x1.mm */,
489D7A2A2550FDC800AD896A /* MetalUnary.hpp */, 489D7A2A2550FDC800AD896A /* MetalUnary.hpp */,
@ -2153,7 +2152,6 @@
489D7A562550FDC800AD896A /* MetalScale.hpp */, 489D7A562550FDC800AD896A /* MetalScale.hpp */,
489D7A572550FDC800AD896A /* MetalCast.hpp */, 489D7A572550FDC800AD896A /* MetalCast.hpp */,
489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */, 489D7A5E2550FDC800AD896A /* MetalConvolutionWinograd.mm */,
489D7A5F2550FDC800AD896A /* MetalDefine.h */,
489D7A622550FDC800AD896A /* MetalPReLU.mm */, 489D7A622550FDC800AD896A /* MetalPReLU.mm */,
489D7A632550FDC800AD896A /* MetalBinary.hpp */, 489D7A632550FDC800AD896A /* MetalBinary.hpp */,
489D7A652550FDC800AD896A /* MetalReLU6.mm */, 489D7A652550FDC800AD896A /* MetalReLU6.mm */,
@ -2293,7 +2291,6 @@
4D9A933526255BDA00F9B43C /* backend */ = { 4D9A933526255BDA00F9B43C /* backend */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
4D0C80E42862FC4700C7CAD6 /* CoreMLRaster.metal */,
4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */, 4D0C80E22862FC4100C7CAD6 /* CoreMLOPRegister.cpp */,
4D4DAE67263905390060D37E /* CoreMLDefine.h */, 4D4DAE67263905390060D37E /* CoreMLDefine.h */,
4DDE2018263809920085AC8F /* CoreMLExecutorWrapper.h */, 4DDE2018263809920085AC8F /* CoreMLExecutorWrapper.h */,
@ -2891,6 +2888,7 @@
C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */, C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */,
1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */, 1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */,
1F501F872397BA5B004E8721 /* Matrix.h in Headers */, 1F501F872397BA5B004E8721 /* Matrix.h in Headers */,
CE8049AC2B31C65B009B422C /* CPULayerNorm.hpp in Headers */,
CECF8C5A299CACFD00D3875B /* WorkerThread.hpp in Headers */, CECF8C5A299CACFD00D3875B /* WorkerThread.hpp in Headers */,
48C84B85250F711700EE7666 /* IfModule.hpp in Headers */, 48C84B85250F711700EE7666 /* IfModule.hpp in Headers */,
4D9A937326255BDA00F9B43C /* CoreMLUnary.hpp in Headers */, 4D9A937326255BDA00F9B43C /* CoreMLUnary.hpp in Headers */,
@ -2913,7 +2911,6 @@
92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */, 92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */,
4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */, 4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */,
CECF8C85299CAD9400D3875B /* log_util.h in Headers */, CECF8C85299CAD9400D3875B /* log_util.h in Headers */,
489D7AB02550FDC900AD896A /* MetalDefine.h in Headers */,
4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */, 4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */,
4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */, 4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */,
92FF027A23AA0B5A00AC97F6 /* CPUPool.hpp in Headers */, 92FF027A23AA0B5A00AC97F6 /* CPUPool.hpp in Headers */,
@ -2972,6 +2969,7 @@
4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */, 4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */,
92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */, 92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */,
4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */, 4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */,
CEA49AA92AFD010900971CB7 /* MetalExecution.hpp in Headers */,
92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */, 92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */,
48C84BA0250F725600EE7666 /* InitNet.hpp in Headers */, 48C84BA0250F725600EE7666 /* InitNet.hpp in Headers */,
92FF03C623AA0B5A00AC97F6 /* CPUNonMaxSuppressionV2.hpp in Headers */, 92FF03C623AA0B5A00AC97F6 /* CPUNonMaxSuppressionV2.hpp in Headers */,
@ -3091,7 +3089,6 @@
481C2DF125FE2CD6001ED6DF /* Arm82OptFunc.hpp in Headers */, 481C2DF125FE2CD6001ED6DF /* Arm82OptFunc.hpp in Headers */,
4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */, 4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */,
C43C8225251894F400A0FF84 /* WingoradGenerater.hpp in Headers */, C43C8225251894F400A0FF84 /* WingoradGenerater.hpp in Headers */,
489D7A6A2550FDC800AD896A /* MetalConvolutionGEMM.hpp in Headers */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
}; };
@ -3404,7 +3401,6 @@
92FF041E23AA0B7100AC97F6 /* ShapeRange.cpp in Sources */, 92FF041E23AA0B7100AC97F6 /* ShapeRange.cpp in Sources */,
489D7AA42550FDC900AD896A /* MetalROIPooling.mm in Sources */, 489D7AA42550FDC900AD896A /* MetalROIPooling.mm in Sources */,
92FF03B423AA0B5A00AC97F6 /* Convolution1x1Strassen.cpp in Sources */, 92FF03B423AA0B5A00AC97F6 /* Convolution1x1Strassen.cpp in Sources */,
489D7A772550FDC800AD896A /* MetalConvolutionGEMM.mm in Sources */,
92FF031623AA0B5A00AC97F6 /* MNNMatrixMax.S in Sources */, 92FF031623AA0B5A00AC97F6 /* MNNMatrixMax.S in Sources */,
92FF043A23AA0B7100AC97F6 /* ShapePermute.cpp in Sources */, 92FF043A23AA0B7100AC97F6 /* ShapePermute.cpp in Sources */,
489D7A8E2550FDC900AD896A /* MetalPooling.mm in Sources */, 489D7A8E2550FDC900AD896A /* MetalPooling.mm in Sources */,
@ -3502,6 +3498,7 @@
489D7A9A2550FDC900AD896A /* MetalConvolutionCommon.mm in Sources */, 489D7A9A2550FDC900AD896A /* MetalConvolutionCommon.mm in Sources */,
92FF044623AA0B7100AC97F6 /* ShapeInnerProduct.cpp in Sources */, 92FF044623AA0B7100AC97F6 /* ShapeInnerProduct.cpp in Sources */,
48123007269EA84800EB7ABA /* CPUUnique.cpp in Sources */, 48123007269EA84800EB7ABA /* CPUUnique.cpp in Sources */,
CEA49AA82AFD010900971CB7 /* MetalExecution.mm in Sources */,
92FF036F23AA0B5A00AC97F6 /* CPURuntime.cpp in Sources */, 92FF036F23AA0B5A00AC97F6 /* CPURuntime.cpp in Sources */,
92FF039D23AA0B5A00AC97F6 /* StrassenMatmulComputor.cpp in Sources */, 92FF039D23AA0B5A00AC97F6 /* StrassenMatmulComputor.cpp in Sources */,
92FF030B23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */, 92FF030B23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */,
@ -3594,7 +3591,6 @@
950B28E229F627E00002F454 /* MNNBinarySubInt8.S in Sources */, 950B28E229F627E00002F454 /* MNNBinarySubInt8.S in Sources */,
950B28F029F627F70002F454 /* MNNBinarySubInt8.S in Sources */, 950B28F029F627F70002F454 /* MNNBinarySubInt8.S in Sources */,
4A224A0C27D0C2D9000A9260 /* ConvolutionPackWinograd.cpp in Sources */, 4A224A0C27D0C2D9000A9260 /* ConvolutionPackWinograd.cpp in Sources */,
4D0C80E52862FC4700C7CAD6 /* CoreMLRaster.metal in Sources */,
92FF044123AA0B7100AC97F6 /* ShapeMoments.cpp in Sources */, 92FF044123AA0B7100AC97F6 /* ShapeMoments.cpp in Sources */,
950B28FA2A0C9AC20002F454 /* CPUScaleInt8.cpp in Sources */, 950B28FA2A0C9AC20002F454 /* CPUScaleInt8.cpp in Sources */,
4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */, 4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */,
@ -4164,7 +4160,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3; PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.9999ve;
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
}; };
@ -4189,7 +4185,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.v3; PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.9999ve;
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
}; };
@ -4205,7 +4201,7 @@
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CODE_SIGN_STYLE = Automatic; CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1; CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = Q48UX93J22; DEVELOPMENT_TEAM = 6G7464HHUS;
GENERATE_INFOPLIST_FILE = YES; GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = demo/Info.plist; INFOPLIST_FILE = demo/Info.plist;
INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo"; INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo";
@ -4221,7 +4217,7 @@
MARKETING_VERSION = 1.0; MARKETING_VERSION = 1.0;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES; MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.9999;
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_EMIT_LOC_STRINGS = YES;
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
@ -4238,7 +4234,7 @@
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CODE_SIGN_STYLE = Automatic; CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1; CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = Q48UX93J22; DEVELOPMENT_TEAM = 6G7464HHUS;
GENERATE_INFOPLIST_FILE = YES; GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = demo/Info.plist; INFOPLIST_FILE = demo/Info.plist;
INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo"; INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo";
@ -4253,7 +4249,7 @@
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
MARKETING_VERSION = 1.0; MARKETING_VERSION = 1.0;
MTL_FAST_MATH = YES; MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.9999;
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_EMIT_LOC_STRINGS = YES;
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
View File
@ -255,7 +255,7 @@ struct GpuCache {
[enc setTexture:inputTexture atIndex:0]; [enc setTexture:inputTexture atIndex:0];
[enc setBuffer:_cache->_constant offset:0 atIndex:1]; [enc setBuffer:_cache->_constant offset:0 atIndex:1];
MNNMetalTensorContent sharedContent; MNNMetalTensorContent sharedContent;
MNNMetalGetTensorContent(&sharedContent, _input); _input->getDeviceInfo(&sharedContent, MNN_FORWARD_METAL);
// For Metal Context to write, don't need finish, just use flush // For Metal Context to write, don't need finish, just use flush
_input->wait(MNN::Tensor::MAP_TENSOR_WRITE, false); _input->wait(MNN::Tensor::MAP_TENSOR_WRITE, false);
[enc setBuffer:sharedContent.buffer offset:sharedContent.offset atIndex:0]; [enc setBuffer:sharedContent.buffer offset:sharedContent.offset atIndex:0];
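Note: the hunk above replaces the playground's private MNNMetalGetTensorContent() helper with the public Tensor::getDeviceInfo() query. A minimal sketch of the new call pattern, assuming MNNMetalTensorContent is declared by MNN/MNNSharedContext.h and the file is compiled as Objective-C++ (names outside this diff are illustrative):

```cpp
#include <MNN/Tensor.hpp>
#include <MNN/MNNForwardType.h>
#include <MNN/MNNSharedContext.h>   // assumed location of MNNMetalTensorContent

// Sketch only; error handling omitted.
void fetchMetalBuffer(MNN::Tensor* input, MNNMetalTensorContent& content) {
    // Replaces the old MNNMetalGetTensorContent() helper.
    input->getDeviceInfo(&content, MNN_FORWARD_METAL);
    // The Metal context only writes this tensor, so a flush (finish == false) suffices.
    input->wait(MNN::Tensor::MAP_TENSOR_WRITE, false);
    // content.buffer / content.offset can then be bound on the compute encoder.
}
```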
View File
@ -78,7 +78,7 @@ def build_deps():
if IS_WINDOWS: if IS_WINDOWS:
os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\ os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\
-DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\ -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\
-DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNTrain MNNConvertDeps') -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNConvertDeps')
elif IS_LINUX: elif IS_LINUX:
extra_opts += '-DMNN_TENSORRT=ON \ extra_opts += '-DMNN_TENSORRT=ON \
-DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' ' -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' '
@ -98,9 +98,9 @@ def build_deps():
extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' ' extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' '
print(extra_opts) print(extra_opts)
os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \ os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
-DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\ -DMNN_BUILD_SHARED_LIBS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\
-DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \ -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
.. && make MNN MNNTrain MNNConvertDeps -j32') .. && make MNN MNNConvertDeps -j4')
################################################################################ ################################################################################
# Building dependent libraries # Building dependent libraries
################################################################################ ################################################################################
View File
@ -224,6 +224,9 @@ def configure_extension_build():
if USE_TRT: if USE_TRT:
engine_depend += trt_depend engine_depend += trt_depend
if IS_DARWIN:
lib_files += [('lib', [os.path.join(root_dir, BUILD_DIR, "libMNN.dylib")])]
lib_files += [('lib', [os.path.join(root_dir, BUILD_DIR, "tools","converter", "libMNNConvertDeps.dylib")])]
if USE_CUDA: if USE_CUDA:
engine_depend += cuda_depend engine_depend += cuda_depend
@ -307,9 +310,7 @@ def configure_extension_build():
if IS_DARWIN: if IS_DARWIN:
engine_link_args += ['-stdlib=libc++'] engine_link_args += ['-stdlib=libc++']
engine_link_args += ['-Wl,-all_load']
engine_link_args += engine_depend engine_link_args += engine_depend
engine_link_args += ['-Wl,-noall_load']
if IS_LINUX: if IS_LINUX:
engine_link_args += ['-Wl,--whole-archive'] engine_link_args += ['-Wl,--whole-archive']
engine_link_args += engine_depend engine_link_args += engine_depend
@ -318,9 +319,7 @@ def configure_extension_build():
if IS_WINDOWS: if IS_WINDOWS:
engine_link_args += ['/WHOLEARCHIVE:MNN.lib'] engine_link_args += ['/WHOLEARCHIVE:MNN.lib']
if IS_DARWIN: if IS_DARWIN:
tools_link_args += ['-Wl,-all_load']
tools_link_args += tools_depend tools_link_args += tools_depend
tools_link_args += ['-Wl,-noall_load']
if IS_LINUX: if IS_LINUX:
tools_link_args += ['-Wl,--whole-archive'] tools_link_args += ['-Wl,--whole-archive']
tools_link_args += tools_depend tools_link_args += tools_depend
View File
@ -1499,6 +1499,13 @@ static PyObject* PyMNNExpr_transpose(PyObject *self, PyObject *args) {
} }
PyMNN_ERROR("transpose require args: (Var, [int]|Var)"); PyMNN_ERROR("transpose require args: (Var, [int]|Var)");
} }
static PyObject* PyMNNExpr_reverse(PyObject *self, PyObject *args) {
PyObject *x, *y;
if (PyArg_ParseTuple(args, "OO", &x, &y) && isVar(x) && isVar(y)) {
return toPyObj(Express::_Reverse(toVar(x), toVar(y)));
}
PyMNN_ERROR("reverse require args: (Var, Var)");
}
static PyObject* PyMNNExpr_reverse_sequence(PyObject *self, PyObject *args) { static PyObject* PyMNNExpr_reverse_sequence(PyObject *self, PyObject *args) {
PyObject *x, *y; PyObject *x, *y;
int batchDim, seqDim; int batchDim, seqDim;
@ -1839,6 +1846,7 @@ static PyMethodDef PyMNNExpr_methods[] = {
{"transpose", PyMNNExpr_transpose, METH_VARARGS, "build transpose: (Var, [int]/Var)"}, {"transpose", PyMNNExpr_transpose, METH_VARARGS, "build transpose: (Var, [int]/Var)"},
register_methods(Expr, register_methods(Expr,
channel_shuffle, "build channel_shuffle expr", channel_shuffle, "build channel_shuffle expr",
reverse, "build reverse expr",
reverse_sequence, "build reverse_sequence expr", reverse_sequence, "build reverse_sequence expr",
crop, "build crop expr", crop, "build crop expr",
resize, "build resize expr", resize, "build resize expr",
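Note: the new reverse binding forwards to Express::_Reverse(x, axis). A minimal C++ sketch of the expression API it wraps (shapes and the axis value are illustrative, not taken from this diff):

```cpp
#include <MNN/expr/ExprCreator.hpp>

using namespace MNN::Express;

int main() {
    float data[] = {1, 2, 3, 4, 5, 6};
    auto x = _Const(data, {2, 3}, NCHW);

    int axisData = 0;                 // reverse along the first axis
    auto axis = _Const(&axisData, {1}, NCHW, halide_type_of<int32_t>());

    auto y = _Reverse(x, axis);       // same call the Python binding dispatches to
    auto ptr = y->readMap<float>();   // expected row order: {4,5,6}, {1,2,3}
    return ptr == nullptr;
}
```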
View File
@ -76,12 +76,6 @@ struct BatchNormT;
struct Scale; struct Scale;
struct ScaleT; struct ScaleT;
struct QuantizeLinear;
struct QuantizeLinearT;
struct DequantizeLinear;
struct DequantizeLinearT;
struct Eltwise; struct Eltwise;
struct EltwiseT; struct EltwiseT;
@ -165,10 +159,6 @@ inline const flatbuffers::TypeTable *BatchNormTypeTable();
inline const flatbuffers::TypeTable *ScaleTypeTable(); inline const flatbuffers::TypeTable *ScaleTypeTable();
inline const flatbuffers::TypeTable *QuantizeLinearTypeTable();
inline const flatbuffers::TypeTable *DequantizeLinearTypeTable();
inline const flatbuffers::TypeTable *EltwiseTypeTable(); inline const flatbuffers::TypeTable *EltwiseTypeTable();
inline const flatbuffers::TypeTable *FlattenTypeTable(); inline const flatbuffers::TypeTable *FlattenTypeTable();
@ -1149,13 +1139,15 @@ struct QuantizedFloatParamT : public flatbuffers::NativeTable {
int8_t clampMin; int8_t clampMin;
int8_t clampMax; int8_t clampMax;
std::vector<int32_t> winogradAttr; std::vector<int32_t> winogradAttr;
DataType outputDataType;
QuantizedFloatParamT() QuantizedFloatParamT()
: method(QuantizeAlgo_DEFAULT), : method(QuantizeAlgo_DEFAULT),
nbits(8), nbits(8),
zeroPoint(0), zeroPoint(0),
outputZeroPoint(0), outputZeroPoint(0),
clampMin(-128), clampMin(-128),
clampMax(127) { clampMax(127),
outputDataType(DataType_DT_INT8) {
} }
}; };
@ -1197,6 +1189,9 @@ struct QuantizedFloatParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table
const flatbuffers::Vector<int32_t> *winogradAttr() const { const flatbuffers::Vector<int32_t> *winogradAttr() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(24); return GetPointer<const flatbuffers::Vector<int32_t> *>(24);
} }
DataType outputDataType() const {
return static_cast<DataType>(GetField<int32_t>(26, 6));
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyOffset(verifier, 4) && VerifyOffset(verifier, 4) &&
@ -1215,6 +1210,7 @@ struct QuantizedFloatParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table
VerifyField<int8_t>(verifier, 22) && VerifyField<int8_t>(verifier, 22) &&
VerifyOffset(verifier, 24) && VerifyOffset(verifier, 24) &&
verifier.VerifyVector(winogradAttr()) && verifier.VerifyVector(winogradAttr()) &&
VerifyField<int32_t>(verifier, 26) &&
verifier.EndTable(); verifier.EndTable();
} }
QuantizedFloatParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; QuantizedFloatParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -1258,6 +1254,9 @@ struct QuantizedFloatParamBuilder {
void add_winogradAttr(flatbuffers::Offset<flatbuffers::Vector<int32_t>> winogradAttr) { void add_winogradAttr(flatbuffers::Offset<flatbuffers::Vector<int32_t>> winogradAttr) {
fbb_.AddOffset(24, winogradAttr); fbb_.AddOffset(24, winogradAttr);
} }
void add_outputDataType(DataType outputDataType) {
fbb_.AddElement<int32_t>(26, static_cast<int32_t>(outputDataType), 6);
}
explicit QuantizedFloatParamBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit QuantizedFloatParamBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -1282,8 +1281,10 @@ inline flatbuffers::Offset<QuantizedFloatParam> CreateQuantizedFloatParam(
int8_t outputZeroPoint = 0, int8_t outputZeroPoint = 0,
int8_t clampMin = -128, int8_t clampMin = -128,
int8_t clampMax = 127, int8_t clampMax = 127,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> winogradAttr = 0) { flatbuffers::Offset<flatbuffers::Vector<int32_t>> winogradAttr = 0,
DataType outputDataType = DataType_DT_INT8) {
QuantizedFloatParamBuilder builder_(_fbb); QuantizedFloatParamBuilder builder_(_fbb);
builder_.add_outputDataType(outputDataType);
builder_.add_winogradAttr(winogradAttr); builder_.add_winogradAttr(winogradAttr);
builder_.add_nbits(nbits); builder_.add_nbits(nbits);
builder_.add_tensorScale(tensorScale); builder_.add_tensorScale(tensorScale);
@ -2922,180 +2923,6 @@ inline flatbuffers::Offset<Scale> CreateScale(
flatbuffers::Offset<Scale> CreateScale(flatbuffers::FlatBufferBuilder &_fbb, const ScaleT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); flatbuffers::Offset<Scale> CreateScale(flatbuffers::FlatBufferBuilder &_fbb, const ScaleT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct QuantizeLinearT : public flatbuffers::NativeTable {
typedef QuantizeLinear TableType;
int32_t scaleSize;
int32_t scaleAxis;
std::vector<float> scaleData;
std::vector<int8_t> zeroPointData;
QuantizeLinearT()
: scaleSize(0),
scaleAxis(0) {
}
};
struct QuantizeLinear FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef QuantizeLinearT NativeTableType;
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
return QuantizeLinearTypeTable();
}
int32_t scaleSize() const {
return GetField<int32_t>(4, 0);
}
int32_t scaleAxis() const {
return GetField<int32_t>(6, 0);
}
const flatbuffers::Vector<float> *scaleData() const {
return GetPointer<const flatbuffers::Vector<float> *>(8);
}
const flatbuffers::Vector<int8_t> *zeroPointData() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(10);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, 4) &&
VerifyField<int32_t>(verifier, 6) &&
VerifyOffset(verifier, 8) &&
verifier.VerifyVector(scaleData()) &&
VerifyOffset(verifier, 10) &&
verifier.VerifyVector(zeroPointData()) &&
verifier.EndTable();
}
QuantizeLinearT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(QuantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<QuantizeLinear> Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct QuantizeLinearBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_scaleSize(int32_t scaleSize) {
fbb_.AddElement<int32_t>(4, scaleSize, 0);
}
void add_scaleAxis(int32_t scaleAxis) {
fbb_.AddElement<int32_t>(6, scaleAxis, 0);
}
void add_scaleData(flatbuffers::Offset<flatbuffers::Vector<float>> scaleData) {
fbb_.AddOffset(8, scaleData);
}
void add_zeroPointData(flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData) {
fbb_.AddOffset(10, zeroPointData);
}
explicit QuantizeLinearBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
QuantizeLinearBuilder &operator=(const QuantizeLinearBuilder &);
flatbuffers::Offset<QuantizeLinear> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<QuantizeLinear>(end);
return o;
}
};
inline flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t scaleSize = 0,
int32_t scaleAxis = 0,
flatbuffers::Offset<flatbuffers::Vector<float>> scaleData = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData = 0) {
QuantizeLinearBuilder builder_(_fbb);
builder_.add_zeroPointData(zeroPointData);
builder_.add_scaleData(scaleData);
builder_.add_scaleAxis(scaleAxis);
builder_.add_scaleSize(scaleSize);
return builder_.Finish();
}
flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct DequantizeLinearT : public flatbuffers::NativeTable {
typedef DequantizeLinear TableType;
int32_t scaleSize;
int32_t scaleAxis;
std::vector<float> scaleData;
std::vector<int8_t> zeroPointData;
DequantizeLinearT()
: scaleSize(0),
scaleAxis(0) {
}
};
struct DequantizeLinear FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef DequantizeLinearT NativeTableType;
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
return DequantizeLinearTypeTable();
}
int32_t scaleSize() const {
return GetField<int32_t>(4, 0);
}
int32_t scaleAxis() const {
return GetField<int32_t>(6, 0);
}
const flatbuffers::Vector<float> *scaleData() const {
return GetPointer<const flatbuffers::Vector<float> *>(8);
}
const flatbuffers::Vector<int8_t> *zeroPointData() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(10);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, 4) &&
VerifyField<int32_t>(verifier, 6) &&
VerifyOffset(verifier, 8) &&
verifier.VerifyVector(scaleData()) &&
VerifyOffset(verifier, 10) &&
verifier.VerifyVector(zeroPointData()) &&
verifier.EndTable();
}
DequantizeLinearT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(DequantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<DequantizeLinear> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct DequantizeLinearBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_scaleSize(int32_t scaleSize) {
fbb_.AddElement<int32_t>(4, scaleSize, 0);
}
void add_scaleAxis(int32_t scaleAxis) {
fbb_.AddElement<int32_t>(6, scaleAxis, 0);
}
void add_scaleData(flatbuffers::Offset<flatbuffers::Vector<float>> scaleData) {
fbb_.AddOffset(8, scaleData);
}
void add_zeroPointData(flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData) {
fbb_.AddOffset(10, zeroPointData);
}
explicit DequantizeLinearBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
DequantizeLinearBuilder &operator=(const DequantizeLinearBuilder &);
flatbuffers::Offset<DequantizeLinear> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<DequantizeLinear>(end);
return o;
}
};
inline flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t scaleSize = 0,
int32_t scaleAxis = 0,
flatbuffers::Offset<flatbuffers::Vector<float>> scaleData = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> zeroPointData = 0) {
DequantizeLinearBuilder builder_(_fbb);
builder_.add_zeroPointData(zeroPointData);
builder_.add_scaleData(scaleData);
builder_.add_scaleAxis(scaleAxis);
builder_.add_scaleSize(scaleSize);
return builder_.Finish();
}
flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct EltwiseT : public flatbuffers::NativeTable { struct EltwiseT : public flatbuffers::NativeTable {
typedef Eltwise TableType; typedef Eltwise TableType;
EltwiseType type; EltwiseType type;
@ -4672,6 +4499,7 @@ inline void QuantizedFloatParam::UnPackTo(QuantizedFloatParamT *_o, const flatbu
{ auto _e = clampMin(); _o->clampMin = _e; }; { auto _e = clampMin(); _o->clampMin = _e; };
{ auto _e = clampMax(); _o->clampMax = _e; }; { auto _e = clampMax(); _o->clampMax = _e; };
{ auto _e = winogradAttr(); if (_e) { _o->winogradAttr.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->winogradAttr[_i] = _e->Get(_i); } } }; { auto _e = winogradAttr(); if (_e) { _o->winogradAttr.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->winogradAttr[_i] = _e->Get(_i); } } };
{ auto _e = outputDataType(); _o->outputDataType = _e; };
} }
inline flatbuffers::Offset<QuantizedFloatParam> QuantizedFloatParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizedFloatParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<QuantizedFloatParam> QuantizedFloatParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizedFloatParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -4693,6 +4521,7 @@ inline flatbuffers::Offset<QuantizedFloatParam> CreateQuantizedFloatParam(flatbu
auto _clampMin = _o->clampMin; auto _clampMin = _o->clampMin;
auto _clampMax = _o->clampMax; auto _clampMax = _o->clampMax;
auto _winogradAttr = _o->winogradAttr.size() ? _fbb.CreateVector(_o->winogradAttr) : 0; auto _winogradAttr = _o->winogradAttr.size() ? _fbb.CreateVector(_o->winogradAttr) : 0;
auto _outputDataType = _o->outputDataType;
return MNN::CreateQuantizedFloatParam( return MNN::CreateQuantizedFloatParam(
_fbb, _fbb,
_weight, _weight,
@ -4705,7 +4534,8 @@ inline flatbuffers::Offset<QuantizedFloatParam> CreateQuantizedFloatParam(flatbu
_outputZeroPoint, _outputZeroPoint,
_clampMin, _clampMin,
_clampMax, _clampMax,
_winogradAttr); _winogradAttr,
_outputDataType);
} }
inline Convolution2DT *Convolution2D::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline Convolution2DT *Convolution2D::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -5342,76 +5172,6 @@ inline flatbuffers::Offset<Scale> CreateScale(flatbuffers::FlatBufferBuilder &_f
_external); _external);
} }
inline QuantizeLinearT *QuantizeLinear::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new QuantizeLinearT();
UnPackTo(_o, _resolver);
return _o;
}
inline void QuantizeLinear::UnPackTo(QuantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = scaleSize(); _o->scaleSize = _e; };
{ auto _e = scaleAxis(); _o->scaleAxis = _e; };
{ auto _e = scaleData(); if (_e) { _o->scaleData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scaleData[_i] = _e->Get(_i); } } };
{ auto _e = zeroPointData(); if (_e) { _o->zeroPointData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zeroPointData[_i] = _e->Get(_i); } } };
}
inline flatbuffers::Offset<QuantizeLinear> QuantizeLinear::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateQuantizeLinear(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<QuantizeLinear> CreateQuantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizeLinearT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _scaleSize = _o->scaleSize;
auto _scaleAxis = _o->scaleAxis;
auto _scaleData = _o->scaleData.size() ? _fbb.CreateVector(_o->scaleData) : 0;
auto _zeroPointData = _o->zeroPointData.size() ? _fbb.CreateVector(_o->zeroPointData) : 0;
return MNN::CreateQuantizeLinear(
_fbb,
_scaleSize,
_scaleAxis,
_scaleData,
_zeroPointData);
}
inline DequantizeLinearT *DequantizeLinear::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new DequantizeLinearT();
UnPackTo(_o, _resolver);
return _o;
}
inline void DequantizeLinear::UnPackTo(DequantizeLinearT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = scaleSize(); _o->scaleSize = _e; };
{ auto _e = scaleAxis(); _o->scaleAxis = _e; };
{ auto _e = scaleData(); if (_e) { _o->scaleData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scaleData[_i] = _e->Get(_i); } } };
{ auto _e = zeroPointData(); if (_e) { _o->zeroPointData.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zeroPointData[_i] = _e->Get(_i); } } };
}
inline flatbuffers::Offset<DequantizeLinear> DequantizeLinear::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateDequantizeLinear(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<DequantizeLinear> CreateDequantizeLinear(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeLinearT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DequantizeLinearT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _scaleSize = _o->scaleSize;
auto _scaleAxis = _o->scaleAxis;
auto _scaleData = _o->scaleData.size() ? _fbb.CreateVector(_o->scaleData) : 0;
auto _zeroPointData = _o->zeroPointData.size() ? _fbb.CreateVector(_o->zeroPointData) : 0;
return MNN::CreateDequantizeLinear(
_fbb,
_scaleSize,
_scaleAxis,
_scaleData,
_zeroPointData);
}
inline EltwiseT *Eltwise::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline EltwiseT *Eltwise::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new EltwiseT(); auto _o = new EltwiseT();
UnPackTo(_o, _resolver); UnPackTo(_o, _resolver);
@ -6243,10 +6003,12 @@ inline const flatbuffers::TypeTable *QuantizedFloatParamTypeTable() {
{ flatbuffers::ET_CHAR, 0, -1 }, { flatbuffers::ET_CHAR, 0, -1 },
{ flatbuffers::ET_CHAR, 0, -1 }, { flatbuffers::ET_CHAR, 0, -1 },
{ flatbuffers::ET_CHAR, 0, -1 }, { flatbuffers::ET_CHAR, 0, -1 },
{ flatbuffers::ET_INT, 1, -1 } { flatbuffers::ET_INT, 1, -1 },
{ flatbuffers::ET_INT, 0, 1 }
}; };
static const flatbuffers::TypeFunction type_refs[] = { static const flatbuffers::TypeFunction type_refs[] = {
QuantizeAlgoTypeTable QuantizeAlgoTypeTable,
DataTypeTypeTable
}; };
static const char * const names[] = { static const char * const names[] = {
"weight", "weight",
@ -6259,10 +6021,11 @@ inline const flatbuffers::TypeTable *QuantizedFloatParamTypeTable() {
"outputZeroPoint", "outputZeroPoint",
"clampMin", "clampMin",
"clampMax", "clampMax",
"winogradAttr" "winogradAttr",
"outputDataType"
}; };
static const flatbuffers::TypeTable tt = { static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 11, type_codes, type_refs, nullptr, names flatbuffers::ST_TABLE, 12, type_codes, type_refs, nullptr, names
}; };
return &tt; return &tt;
} }
@ -6648,44 +6411,6 @@ inline const flatbuffers::TypeTable *ScaleTypeTable() {
return &tt; return &tt;
} }
inline const flatbuffers::TypeTable *QuantizeLinearTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_INT, 0, -1 },
{ flatbuffers::ET_INT, 0, -1 },
{ flatbuffers::ET_FLOAT, 1, -1 },
{ flatbuffers::ET_CHAR, 1, -1 }
};
static const char * const names[] = {
"scaleSize",
"scaleAxis",
"scaleData",
"zeroPointData"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names
};
return &tt;
}
inline const flatbuffers::TypeTable *DequantizeLinearTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_INT, 0, -1 },
{ flatbuffers::ET_INT, 0, -1 },
{ flatbuffers::ET_FLOAT, 1, -1 },
{ flatbuffers::ET_CHAR, 1, -1 }
};
static const char * const names[] = {
"scaleSize",
"scaleAxis",
"scaleData",
"zeroPointData"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 4, type_codes, nullptr, nullptr, names
};
return &tt;
}
inline const flatbuffers::TypeTable *EltwiseTypeTable() { inline const flatbuffers::TypeTable *EltwiseTypeTable() {
static const flatbuffers::TypeCode type_codes[] = { static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_CHAR, 0, 0 }, { flatbuffers::ET_CHAR, 0, 0 },
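Note: the regenerated table adds an outputDataType field (slot 26, default DataType_DT_INT8) to QuantizedFloatParam. A small sketch of writing it through the generated builder, assuming the generated headers are included as usual (the other field values are illustrative):

```cpp
#include "MNN_generated.h"   // assumed umbrella header for the generated tables

// Sketch: build a QuantizedFloatParam whose output stays in float.
flatbuffers::Offset<MNN::QuantizedFloatParam> buildParam(flatbuffers::FlatBufferBuilder& fbb) {
    MNN::QuantizedFloatParamBuilder builder(fbb);
    builder.add_nbits(8);
    builder.add_clampMin(-128);
    builder.add_clampMax(127);
    builder.add_outputDataType(MNN::DataType_DT_FLOAT);   // new field; omitted => DT_INT8
    return builder.Finish();
}
```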
View File
@ -236,8 +236,6 @@ enum OpType {
OpType_GatherElements = 152, OpType_GatherElements = 152,
OpType_Svd = 153, OpType_Svd = 153,
OpType_Histogram = 154, OpType_Histogram = 154,
OpType_QuantizeLinear = 155,
OpType_DequantizeLinear = 156,
OpType_Plugin = 256, OpType_Plugin = 256,
OpType_Select = 257, OpType_Select = 257,
OpType_ZerosLike = 258, OpType_ZerosLike = 258,
@ -267,7 +265,7 @@ enum OpType {
OpType_MAX = OpType_GridSample OpType_MAX = OpType_GridSample
}; };
inline const OpType (&EnumValuesOpType())[177] { inline const OpType (&EnumValuesOpType())[175] {
static const OpType values[] = { static const OpType values[] = {
OpType_AbsVal, OpType_AbsVal,
OpType_QuantizedAdd, OpType_QuantizedAdd,
@ -419,8 +417,6 @@ inline const OpType (&EnumValuesOpType())[177] {
OpType_GatherElements, OpType_GatherElements,
OpType_Svd, OpType_Svd,
OpType_Histogram, OpType_Histogram,
OpType_QuantizeLinear,
OpType_DequantizeLinear,
OpType_Plugin, OpType_Plugin,
OpType_Select, OpType_Select,
OpType_ZerosLike, OpType_ZerosLike,
@ -607,8 +603,8 @@ inline const char * const *EnumNamesOpType() {
"GatherElements", "GatherElements",
"Svd", "Svd",
"Histogram", "Histogram",
"QuantizeLinear", "",
"DequantizeLinear", "",
"", "",
"", "",
"", "",
@ -1164,13 +1160,11 @@ enum OpParameter {
OpParameter_LoopParam = 92, OpParameter_LoopParam = 92,
OpParameter_ImageProcessParam = 93, OpParameter_ImageProcessParam = 93,
OpParameter_CumSum = 94, OpParameter_CumSum = 94,
OpParameter_QuantizeLinear = 95,
OpParameter_DequantizeLinear = 96,
OpParameter_MIN = OpParameter_NONE, OpParameter_MIN = OpParameter_NONE,
OpParameter_MAX = OpParameter_DequantizeLinear OpParameter_MAX = OpParameter_CumSum
}; };
inline const OpParameter (&EnumValuesOpParameter())[97] { inline const OpParameter (&EnumValuesOpParameter())[95] {
static const OpParameter values[] = { static const OpParameter values[] = {
OpParameter_NONE, OpParameter_NONE,
OpParameter_QuantizedAdd, OpParameter_QuantizedAdd,
@ -1266,9 +1260,7 @@ inline const OpParameter (&EnumValuesOpParameter())[97] {
OpParameter_GridSample, OpParameter_GridSample,
OpParameter_LoopParam, OpParameter_LoopParam,
OpParameter_ImageProcessParam, OpParameter_ImageProcessParam,
OpParameter_CumSum, OpParameter_CumSum
OpParameter_QuantizeLinear,
OpParameter_DequantizeLinear
}; };
return values; return values;
} }
@ -1370,15 +1362,13 @@ inline const char * const *EnumNamesOpParameter() {
"LoopParam", "LoopParam",
"ImageProcessParam", "ImageProcessParam",
"CumSum", "CumSum",
"QuantizeLinear",
"DequantizeLinear",
nullptr nullptr
}; };
return names; return names;
} }
inline const char *EnumNameOpParameter(OpParameter e) { inline const char *EnumNameOpParameter(OpParameter e) {
if (e < OpParameter_NONE || e > OpParameter_DequantizeLinear) return ""; if (e < OpParameter_NONE || e > OpParameter_CumSum) return "";
const size_t index = static_cast<int>(e); const size_t index = static_cast<int>(e);
return EnumNamesOpParameter()[index]; return EnumNamesOpParameter()[index];
} }
@ -1763,14 +1753,6 @@ template<> struct OpParameterTraits<CumSum> {
static const OpParameter enum_value = OpParameter_CumSum; static const OpParameter enum_value = OpParameter_CumSum;
}; };
template<> struct OpParameterTraits<QuantizeLinear> {
static const OpParameter enum_value = OpParameter_QuantizeLinear;
};
template<> struct OpParameterTraits<DequantizeLinear> {
static const OpParameter enum_value = OpParameter_DequantizeLinear;
};
struct OpParameterUnion { struct OpParameterUnion {
OpParameter type; OpParameter type;
void *value; void *value;
@ -2554,22 +2536,6 @@ struct OpParameterUnion {
return type == OpParameter_CumSum ? return type == OpParameter_CumSum ?
reinterpret_cast<const CumSumT *>(value) : nullptr; reinterpret_cast<const CumSumT *>(value) : nullptr;
} }
QuantizeLinearT *AsQuantizeLinear() {
return type == OpParameter_QuantizeLinear ?
reinterpret_cast<QuantizeLinearT *>(value) : nullptr;
}
const QuantizeLinearT *AsQuantizeLinear() const {
return type == OpParameter_QuantizeLinear ?
reinterpret_cast<const QuantizeLinearT *>(value) : nullptr;
}
DequantizeLinearT *AsDequantizeLinear() {
return type == OpParameter_DequantizeLinear ?
reinterpret_cast<DequantizeLinearT *>(value) : nullptr;
}
const DequantizeLinearT *AsDequantizeLinear() const {
return type == OpParameter_DequantizeLinear ?
reinterpret_cast<const DequantizeLinearT *>(value) : nullptr;
}
}; };
bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type); bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
@ -3633,12 +3599,6 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const CumSum *main_as_CumSum() const { const CumSum *main_as_CumSum() const {
return main_type() == OpParameter_CumSum ? static_cast<const CumSum *>(main()) : nullptr; return main_type() == OpParameter_CumSum ? static_cast<const CumSum *>(main()) : nullptr;
} }
const QuantizeLinear *main_as_QuantizeLinear() const {
return main_type() == OpParameter_QuantizeLinear ? static_cast<const QuantizeLinear *>(main()) : nullptr;
}
const DequantizeLinear *main_as_DequantizeLinear() const {
return main_type() == OpParameter_DequantizeLinear ? static_cast<const DequantizeLinear *>(main()) : nullptr;
}
const flatbuffers::String *name() const { const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(10); return GetPointer<const flatbuffers::String *>(10);
} }
@ -4047,14 +4007,6 @@ template<> inline const CumSum *Op::main_as<CumSum>() const {
return main_as_CumSum(); return main_as_CumSum();
} }
template<> inline const QuantizeLinear *Op::main_as<QuantizeLinear>() const {
return main_as_QuantizeLinear();
}
template<> inline const DequantizeLinear *Op::main_as<DequantizeLinear>() const {
return main_as_DequantizeLinear();
}
struct OpBuilder { struct OpBuilder {
flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_; flatbuffers::uoffset_t start_;
@ -5676,14 +5628,6 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
auto ptr = reinterpret_cast<const CumSum *>(obj); auto ptr = reinterpret_cast<const CumSum *>(obj);
return verifier.VerifyTable(ptr); return verifier.VerifyTable(ptr);
} }
case OpParameter_QuantizeLinear: {
auto ptr = reinterpret_cast<const QuantizeLinear *>(obj);
return verifier.VerifyTable(ptr);
}
case OpParameter_DequantizeLinear: {
auto ptr = reinterpret_cast<const DequantizeLinear *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false; default: return false;
} }
} }
@ -6078,14 +6022,6 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
auto ptr = reinterpret_cast<const CumSum *>(obj); auto ptr = reinterpret_cast<const CumSum *>(obj);
return ptr->UnPack(resolver); return ptr->UnPack(resolver);
} }
case OpParameter_QuantizeLinear: {
auto ptr = reinterpret_cast<const QuantizeLinear *>(obj);
return ptr->UnPack(resolver);
}
case OpParameter_DequantizeLinear: {
auto ptr = reinterpret_cast<const DequantizeLinear *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr; default: return nullptr;
} }
} }
@ -6468,14 +6404,6 @@ inline flatbuffers::Offset<void> OpParameterUnion::Pack(flatbuffers::FlatBufferB
auto ptr = reinterpret_cast<const CumSumT *>(value); auto ptr = reinterpret_cast<const CumSumT *>(value);
return CreateCumSum(_fbb, ptr, _rehasher).Union(); return CreateCumSum(_fbb, ptr, _rehasher).Union();
} }
case OpParameter_QuantizeLinear: {
auto ptr = reinterpret_cast<const QuantizeLinearT *>(value);
return CreateQuantizeLinear(_fbb, ptr, _rehasher).Union();
}
case OpParameter_DequantizeLinear: {
auto ptr = reinterpret_cast<const DequantizeLinearT *>(value);
return CreateDequantizeLinear(_fbb, ptr, _rehasher).Union();
}
default: return 0; default: return 0;
} }
} }
@ -6858,14 +6786,6 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
value = new CumSumT(*reinterpret_cast<CumSumT *>(u.value)); value = new CumSumT(*reinterpret_cast<CumSumT *>(u.value));
break; break;
} }
case OpParameter_QuantizeLinear: {
value = new QuantizeLinearT(*reinterpret_cast<QuantizeLinearT *>(u.value));
break;
}
case OpParameter_DequantizeLinear: {
value = new DequantizeLinearT(*reinterpret_cast<DequantizeLinearT *>(u.value));
break;
}
default: default:
break; break;
} }
@ -7343,16 +7263,6 @@ inline void OpParameterUnion::Reset() {
delete ptr; delete ptr;
break; break;
} }
case OpParameter_QuantizeLinear: {
auto ptr = reinterpret_cast<QuantizeLinearT *>(value);
delete ptr;
break;
}
case OpParameter_DequantizeLinear: {
auto ptr = reinterpret_cast<DequantizeLinearT *>(value);
delete ptr;
break;
}
default: break; default: break;
} }
value = nullptr; value = nullptr;
@ -7535,14 +7445,12 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
{ flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 } { flatbuffers::ET_INT, 0, 0 }
}; };
static const flatbuffers::TypeFunction type_refs[] = { static const flatbuffers::TypeFunction type_refs[] = {
OpTypeTypeTable OpTypeTypeTable
}; };
static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 }; static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
static const char * const names[] = { static const char * const names[] = {
"AbsVal", "AbsVal",
"QuantizedAdd", "QuantizedAdd",
@ -7694,8 +7602,6 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"GatherElements", "GatherElements",
"Svd", "Svd",
"Histogram", "Histogram",
"QuantizeLinear",
"DequantizeLinear",
"Plugin", "Plugin",
"Select", "Select",
"ZerosLike", "ZerosLike",
@ -7723,7 +7629,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"GridSample" "GridSample"
}; };
static const flatbuffers::TypeTable tt = { static const flatbuffers::TypeTable tt = {
flatbuffers::ST_ENUM, 177, type_codes, type_refs, values, names flatbuffers::ST_ENUM, 175, type_codes, type_refs, values, names
}; };
return &tt; return &tt;
} }
@ -7824,9 +7730,7 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
{ flatbuffers::ET_SEQUENCE, 0, 90 }, { flatbuffers::ET_SEQUENCE, 0, 90 },
{ flatbuffers::ET_SEQUENCE, 0, 91 }, { flatbuffers::ET_SEQUENCE, 0, 91 },
{ flatbuffers::ET_SEQUENCE, 0, 92 }, { flatbuffers::ET_SEQUENCE, 0, 92 },
{ flatbuffers::ET_SEQUENCE, 0, 93 }, { flatbuffers::ET_SEQUENCE, 0, 93 }
{ flatbuffers::ET_SEQUENCE, 0, 94 },
{ flatbuffers::ET_SEQUENCE, 0, 95 }
}; };
static const flatbuffers::TypeFunction type_refs[] = { static const flatbuffers::TypeFunction type_refs[] = {
QuantizedAddTypeTable, QuantizedAddTypeTable,
@ -7922,9 +7826,7 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
GridSampleTypeTable, GridSampleTypeTable,
LoopParamTypeTable, LoopParamTypeTable,
ImageProcessParamTypeTable, ImageProcessParamTypeTable,
CumSumTypeTable, CumSumTypeTable
QuantizeLinearTypeTable,
DequantizeLinearTypeTable
}; };
static const char * const names[] = { static const char * const names[] = {
"NONE", "NONE",
@ -8021,12 +7923,10 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
"GridSample", "GridSample",
"LoopParam", "LoopParam",
"ImageProcessParam", "ImageProcessParam",
"CumSum", "CumSum"
"QuantizeLinear",
"DequantizeLinear"
}; };
static const flatbuffers::TypeTable tt = { static const flatbuffers::TypeTable tt = {
flatbuffers::ST_UNION, 97, type_codes, type_refs, nullptr, names flatbuffers::ST_UNION, 95, type_codes, type_refs, nullptr, names
}; };
return &tt; return &tt;
} }
View File
@ -95,6 +95,7 @@ table QuantizedFloatParam{
clampMax: byte = 127; clampMax: byte = 127;
// binary proto: [originKySize, originKxSize, transKySize, transKxSize, {kyStart, kxStart, unitY, unitX}, {...} ...] // binary proto: [originKySize, originKxSize, transKySize, transKxSize, {kyStart, kxStart, unitY, unitX}, {...} ...]
winogradAttr:[int]; winogradAttr:[int];
outputDataType:DataType=DT_INT8;
} }
table Convolution2D { table Convolution2D {
@ -247,20 +248,6 @@ table Scale {
external:[int64]; // [offset, scaleData_bytes_size, biasData_bytes_size] external:[int64]; // [offset, scaleData_bytes_size, biasData_bytes_size]
} }
table QuantizeLinear {
scaleSize: int;
scaleAxis: int;
scaleData:[float];
zeroPointData:[byte];
}
table DequantizeLinear {
scaleSize: int;
scaleAxis: int;
scaleData:[float];
zeroPointData:[byte];
}
enum EltwiseType : byte { enum EltwiseType : byte {
PROD = 0, PROD = 0,
SUM = 1, SUM = 1,
View File
@ -167,8 +167,6 @@ enum OpType : int {
GatherElements = 152, GatherElements = 152,
Svd = 153, Svd = 153,
Histogram = 154, Histogram = 154,
QuantizeLinear = 155,
DequantizeLinear = 156,
Plugin = 256, //The Type load from plugin Plugin = 256, //The Type load from plugin
//Training Op Start from 257 //Training Op Start from 257
@ -392,8 +390,6 @@ union OpParameter {
LoopParam, LoopParam,
ImageProcessParam, ImageProcessParam,
CumSum, CumSum,
QuantizeLinear,
DequantizeLinear,
} }
table Op { table Op {
View File
@ -62,10 +62,12 @@ vadd.f16 q3, q3, q1
vmul.f16 q2, q2, q14 vmul.f16 q2, q2, q14
vmul.f16 q3, q3, q14 vmul.f16 q3, q3, q14
mov lr, #5.0 mov lr, #5
vdup.16 q4, lr vdup.16 q4, lr
mov lr, #-5.0 vcvt.f32.s32 q4, q4
mov lr, #-5
vdup.16 q5, lr vdup.16 q5, lr
vcvt.f32.s32 q5, q5
vmax.f16 q2, q2, q5 vmax.f16 q2, q2, q5
vmin.f16 q2, q2, q4 vmin.f16 q2, q2, q4
vmax.f16 q3, q3, q5 vmax.f16 q3, q3, q5
View File
@ -45,8 +45,8 @@ dup v10.8h, w9 // v10: [28.f]x4
dup v9.8h, w10 // v9: [3150.f]x4 dup v9.8h, w10 // v9: [3150.f]x4
dup v8.8h, w11 // v8: [62370.f]x4 dup v8.8h, w11 // v8: [62370.f]x4
mov w4, #5.0 mov w4, #5
mov w5, #-5.0 mov w5, #-5
GeluZLoop: GeluZLoop:
@ -67,6 +67,8 @@ fmul v3.8h, v3.8h, v14.8h
dup v6.8h, w5 dup v6.8h, w5
dup v7.8h, w4 dup v7.8h, w4
scvtf v6.8h, v6.8h
scvtf v7.8h, v7.8h
fmin v2.8h, v2.8h, v7.8h fmin v2.8h, v2.8h, v7.8h
fmin v3.8h, v3.8h, v7.8h fmin v3.8h, v3.8h, v7.8h
fmax v2.8h, v2.8h, v6.8h fmax v2.8h, v2.8h, v6.8h
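Note: both GELU hunks fix the same problem: mov cannot take a floating-point immediate such as #5.0 into a general-purpose register, so the integer 5/-5 is loaded and then converted inside the vector register (vcvt.f32.s32 / scvtf) before the clamp. A rough scalar equivalent of what the clamp step computes:

```cpp
#include <algorithm>

// Scalar sketch of the NEON clamp above: the tanh-approximation argument is
// limited to [-5, 5] before the polynomial evaluation.
inline float clampGeluArg(float v) {
    return std::min(5.0f, std::max(-5.0f, v));
}
```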
View File
@ -195,6 +195,9 @@ Execution *CPUCastCreator::onCreate(const std::vector<Tensor *> &inputs, const s
if (dstT == MNN::DataType_DT_INT8 && halide_type_of<float>() == inputDataType) { if (dstT == MNN::DataType_DT_INT8 && halide_type_of<float>() == inputDataType) {
return new CastDataType<float, int8_t>(backend); return new CastDataType<float, int8_t>(backend);
} }
if (dstT == MNN::DataType_DT_INT8 && halide_type_of<int32_t>() == inputDataType) {
return new CastDataType<int32_t, int8_t>(backend);
}
if (dstT == MNN::DataType_DT_UINT8 && halide_type_of<float>() == inputDataType) { if (dstT == MNN::DataType_DT_UINT8 && halide_type_of<float>() == inputDataType) {
return new CastDataType<float, uint8_t>(backend); return new CastDataType<float, uint8_t>(backend);
} }
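Note: the new branch adds an int32 -> int8 path next to the existing float -> int8 cast. CastDataType<int32_t, int8_t> is conceptually an element-wise conversion; a simplified sketch (the real execution walks MNN tensors and their layouts):

```cpp
#include <cstdint>
#include <cstddef>

// Sketch of the per-element behaviour: a plain narrowing cast, no saturation.
void castInt32ToInt8(const int32_t* src, int8_t* dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = static_cast<int8_t>(src[i]);
    }
}
```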
View File
@ -1,87 +0,0 @@
//
// CPUDequantizeLinear.cpp
// MNN
//
// Created by MNN on 2018/07/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "backend/cpu/CPUDequantizeLinear.hpp"
#include "core/TensorUtils.hpp"
#include "compute/CommonOptFunction.h"
namespace MNN {
CPUDequantizeLinear::CPUDequantizeLinear(Backend *b, float* scale, int8_t* zeroPoints, int size, int axis, int inputBits) : MNN::Execution(b){
mSize = size;
mAxis = axis;
mInputBits = inputBits;
}
ErrorCode CPUDequantizeLinear::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
if (mInputBits == 8) {
mFunc = dequantizeFunc<int8_t>;
} else if (mInputBits == 16) {
mFunc = dequantizeFunc<int16_t>;
} else {
mFunc = dequantizeFunc<int32_t>;
}
float *scale = inputs[1]->host<float>();
int8_t *zero = nullptr;
if (inputs.size() > 2) {
zero = inputs[2]->host<int8_t>();;
}
if (mSize == 1) {
mQuantScales.resize(4, *scale);
if (nullptr != zero) {
mQuantZeroPoints.resize(4, *zero);
} else {
mQuantZeroPoints.resize(4, 0);
}
} else {
mQuantScales.resize(mSize);
::memcpy(mQuantScales.data(), scale, sizeof(float) * mSize);
if (nullptr != zero) {
mQuantZeroPoints.resize(mSize);
::memcpy(mQuantZeroPoints.data(), zero, mSize);
} else {
mQuantZeroPoints.resize(mSize);
}
}
return NO_ERROR;
}
ErrorCode CPUDequantizeLinear::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto input = inputs[0];
int N = input->length(0);
ssize_t size = N;
auto core = static_cast<CPUBackend*>(backend())->int8Functions();
int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
auto dst = outputs[0]->host<float>();
auto src = input->host<int8_t>();
mFunc(dst, src, input->dimensions(), input->size(), mSize, UNIT, mQuantScales.data(), mQuantZeroPoints.data(), core);
return NO_ERROR;
}
class CPUDequantizeLinearCreator : public CPUBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
auto dataType = inputs[0]->getType();
if (dataType.bits != 8 && dataType.bits != 16 && dataType.bits != 32) {
MNN_ERROR("Input of Dequantize must be int8/uint8/fp16/int32\n");
return nullptr;
}
int inputBits = dataType.bits;
int size = op->main_as_DequantizeLinear()->scaleSize();
int axis = op->main_as_DequantizeLinear()->scaleAxis();
if (inputs.size() > 2) {
return new CPUDequantizeLinear(backend, inputs[1]->host<float>(), inputs[2]->host<int8_t>(), size, axis, inputBits);
}
return new CPUDequantizeLinear(backend, inputs[1]->host<float>(), nullptr, size, axis, inputBits);
}
};
REGISTER_CPU_OP_CREATOR(CPUDequantizeLinearCreator, OpType_DequantizeLinear);
} // namespace MNN
View File
@ -1,81 +0,0 @@
//
// CPUDequantizeLinear.hpp
// MNN
//
// Created by MNN on 2018/07/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef CPUDequantizeLinear_hpp
#define CPUDequantizeLinear_hpp
#include "core/AutoStorage.h"
#include "core/Execution.hpp"
#include "compute/Int8FunctionsOpt.h"
namespace MNN {
typedef void(*dequantFunc)(float* dst, const int8_t* source, int inputDim, int inputSize, int size, int UNIT, float* scales, int8_t* zeros, const CoreInt8Functions* core);
class CPUDequantizeLinear : public Execution {
public:
CPUDequantizeLinear(Backend *b, float* scales, int8_t* zeroPoints, int size = 1, int axis = 0, int inputBits = 8);
virtual ~CPUDequantizeLinear() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:
std::vector<float> mQuantScales;
std::vector<int8_t> mQuantZeroPoints;
int mSize = 1;
int mAxis = 0;
int mInputBits = 8;
dequantFunc mFunc;
};
template<typename T>
void dequantizeFunc(float* dst, const int8_t* source, int inputDim, int inputSize, int size, int UNIT, float* scales, int8_t* zeros, const CoreInt8Functions* core) {
#ifdef MNN_USE_SSE
auto src = (uint8_t*)source;
int offset = 128;
#else
auto src = (int8_t*)source;
int offset = 0;
#endif
// auto src = (T*)source;
if (inputDim == 1) {
for (int i = 0; i < size; ++i) {
dst[i] = static_cast<float>(src[i] - zeros[i] - offset) * scales[i];
}
return;
}
int chw = 1;
if (inputDim > 1) {
chw = inputSize / (size * sizeof(T));
}
if (size == 1) {
if (sizeof(T) == 1) {
core->MNNInt8ScaleToFloat(dst, (int8_t*)src, scales, chw / UNIT, zeros[0]);
int sizeDiv = (int)chw / UNIT;
for (int k = sizeDiv * UNIT; k < chw; ++k) {
dst[k] = static_cast<float>(src[k] - zeros[0] - offset) * scales[0];
}
} else {
for (int k = 0; k < chw; ++k) {
dst[k] = static_cast<float>(src[k] - zeros[0] - offset) * scales[0];
}
}
} else {
for (int i = 0; i < size; ++i) {
std::vector<float> tmp(4, scales[i]);
//core->MNNInt8ScaleToFloat(dst, src, tmp.data(), sizeDiv, mQuantZeroPoints[i]);
for (int k = 0; k < chw; ++k) {
dst[k] = static_cast<float>(src[k] - zeros[i] - offset) * scales[i];
}
src += chw;
dst += chw;
}
}
}
} // namespace MNN
#endif /* CPUDequantizeLinear_hpp */
View File
@ -66,8 +66,6 @@ extern void ___CPUSetDiff1DCreator__OpType_SetDiff1D__();
extern void ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__(); extern void ___CPUEltwiseInt8Creator__OpType_EltwiseInt8__();
extern void ___CPUSvdCreator__OpType_Svd__(); extern void ___CPUSvdCreator__OpType_Svd__();
extern void ___CPULayerNormCreator__OpType_LayerNorm__(); extern void ___CPULayerNormCreator__OpType_LayerNorm__();
extern void ___CPUQuantizeLinearCreator__OpType_QuantizeLinear__();
extern void ___CPUDequantizeLinearCreator__OpType_DequantizeLinear__();
#ifdef MNN_SUPPORT_RENDER #ifdef MNN_SUPPORT_RENDER
extern void ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__(); extern void ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
@ -146,8 +144,5 @@ ___CPURasterAndInterpolateCreator__OpType_RasterAndInterpolate__();
___CPURasterDiffCreator__OpType_RasterDiff__(); ___CPURasterDiffCreator__OpType_RasterDiff__();
___CPUTextureCreator__OpType_Texture__(); ___CPUTextureCreator__OpType_Texture__();
#endif #endif
___CPUQuantizeLinearCreator__OpType_QuantizeLinear__();
___CPUDequantizeLinearCreator__OpType_DequantizeLinear__();
//CPUQuantizeLinearCreator
} }
} }
View File
@ -1,85 +0,0 @@
//
// CPUQuantizeLinear.cpp
// MNN
//
// Created by MNN on 2018/07/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "backend/cpu/CPUQuantizeLinear.hpp"
#include "compute/CommonOptFunction.h"
#include "core/TensorUtils.hpp"
namespace MNN {
CPUQuantizeLinear::CPUQuantizeLinear(Backend *b, int size, int axis) : MNN::Execution(b){
mSize = size;
mAxis = axis;
}
ErrorCode CPUQuantizeLinear::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
int size = mSize;
float* scale = inputs[1]->host<float>();
int8_t* zero = nullptr;
if (inputs.size() > 2) {
zero = inputs[2]->host<int8_t>();
}
if (mSize == 1) {
float s = scale[0] == 0?0: 1/ scale[0];
mQuantScales.resize(4, s);
if (nullptr != zero) {
int8_t z = *zero;
mQuantZeroPoints.resize(4, z);
} else {
mQuantZeroPoints.resize(4);
}
} else { // TODO scale: (1,D)
}
return NO_ERROR;
}
ErrorCode CPUQuantizeLinear::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto input = inputs[0];
int N = input->length(0), C = input->length(1), H = input->length(2), W = input->length(3);
ssize_t size = N * C * H * W;
auto core = static_cast<CPUBackend*>(backend())->int8Functions();
int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int maxValue = 127;
int minValue = -128;
#ifdef MNN_USE_SSE
auto dst = outputs[0]->host<uint8_t>();
int offset = 128;
#else
auto dst = outputs[0]->host<int8_t>();
int offset = 0;
#endif
if (mSize == 1) {
auto src = input->host<float>();
int sizeDiv = (int)size / UNIT;
core->MNNFloat2Int8(src, (int8_t*)dst, size / UNIT, mQuantScales.data(), -128, 127, mQuantZeroPoints[0]);
for (int i = sizeDiv * UNIT; i < size; ++i) {
int v = (int)roundf(src[i] * mQuantScales[0]) + mQuantZeroPoints[0] + offset;
v = std::max(minValue + offset, std::min(maxValue + offset, v));
dst[i] = v;
}
} else {
}
return NO_ERROR;
}
class CPUQuantizeLinearCreator : public CPUBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
int size = op->main_as_QuantizeLinear()->scaleSize();
int axis = op->main_as_QuantizeLinear()->scaleAxis();
return new CPUQuantizeLinear(backend, size, axis);
}
};
REGISTER_CPU_OP_CREATOR(CPUQuantizeLinearCreator, OpType_QuantizeLinear);
} // namespace MNN
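The deleted execute path above quantizes with v = round(x / scale) + zeroPoint, clamped to [-128, 127] (the MNN_USE_SSE branch shifts by 128 to store the result as uint8). A hedged scalar sketch of the per-tensor case, with illustrative names:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Per-tensor linear quantization matching the scalar tail of the removed code.
    inline int8_t quantizeLinear(float x, float scale, int8_t zeroPoint) {
        float inv = (scale == 0.f) ? 0.f : 1.f / scale;   // guard against zero scale
        int v = static_cast<int>(roundf(x * inv)) + zeroPoint;
        v = std::max(-128, std::min(127, v));
        return static_cast<int8_t>(v);
    }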

View File

@ -1,31 +0,0 @@
//
// CPUQuantizeLinear.hpp
// MNN
//
// Created by MNN on 2018/07/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef CPUQuantizeLinear_hpp
#define CPUQuantizeLinear_hpp
#include "core/AutoStorage.h"
#include "core/Execution.hpp"
namespace MNN {
class CPUQuantizeLinear : public Execution {
public:
CPUQuantizeLinear(Backend *b, int size = 1, int axis = 0);
virtual ~CPUQuantizeLinear() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:
std::vector<float> mQuantScales;
std::vector<int8_t> mQuantZeroPoints;
int mSize = 1;
int mAxis = 0;
};
} // namespace MNN
#endif /* CPUQuantizeLinear_hpp */

View File

@ -884,14 +884,14 @@ public:
// Loop Op's command's first index must be output // Loop Op's command's first index must be output
outputStride = cmd->view()->GetAs<View>(0)->stride()->data(); outputStride = cmd->view()->GetAs<View>(0)->stride()->data();
} }
halide_type_t outputType; halide_type_t inputType;
for (int v=0; v<iterIndexsize; ++v) { for (int v=0; v<iterIndexsize; ++v) {
auto tensorIndex = cmd->indexes()->data()[v]; auto tensorIndex = cmd->indexes()->data()[v];
auto tensor = mStack[tensorIndex]; auto tensor = mStack[tensorIndex];
auto iterIndex = cmd->iterIndexes()->data()[v]; auto iterIndex = cmd->iterIndexes()->data()[v];
auto offset = iter; auto offset = iter;
if (0 == v) { if (1 == v) {
outputType = tensor->getType(); inputType = tensor->getType();
} }
if (iterIndex >= 0) { if (iterIndex >= 0) {
offset = mStack[iterIndex]->host<int32_t>()[iter]; offset = mStack[iterIndex]->host<int32_t>()[iter];
@ -969,10 +969,10 @@ public:
if (OpType_BinaryOp == op->type()) { if (OpType_BinaryOp == op->type()) {
auto src0 = mContainer[tId].stackPtr[cmd->indexes()->data()[1]]; auto src0 = mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
MNNBinaryExecute proc; MNNBinaryExecute proc;
if (outputType.code == halide_type_float) { if (inputType.code == halide_type_float) {
proc = static_cast<CPUBackend*>(backend())->functions()->MNNSelectBinaryFunctionForFloat(op->main_as_BinaryOp()->opType()); proc = static_cast<CPUBackend*>(backend())->functions()->MNNSelectBinaryFunctionForFloat(op->main_as_BinaryOp()->opType());
} else { } else {
MNN_ASSERT(outputType.code == halide_type_int); MNN_ASSERT(inputType.code == halide_type_int);
proc = CPUBinary::selectForInt(op->main_as_BinaryOp()->opType()); proc = CPUBinary::selectForInt(op->main_as_BinaryOp()->opType());
} }
auto lastS = cmd->size()->data()[2]; auto lastS = cmd->size()->data()[2];
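The fix above keys the binary-kernel selection off the first input tensor (indexes()[1]) rather than the output, because ops such as Greater/Less read float inputs but write int outputs, so the output type is the wrong dispatch key. A tiny self-contained illustration (not MNN's API):

    #include <cstdint>
    #include <vector>

    // Greater: float inputs, int32 outputs. Dispatching on the output type would
    // wrongly select an integer compare kernel for float data.
    static void greaterF32(const float* a, const float* b, int32_t* out, int n) {
        for (int i = 0; i < n; ++i) out[i] = (a[i] > b[i]) ? 1 : 0;
    }

    int main() {
        std::vector<float> a{1.f, 3.f}, b{2.f, 1.f};
        std::vector<int32_t> out(2);
        greaterF32(a.data(), b.data(), out.data(), 2); // out = {0, 1}
        return 0;
    }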

View File

@ -1531,6 +1531,8 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) {
cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH || cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH ||
cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL; cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL;
cpuinfo_isa->i8mm = cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH ||
cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL;
#endif // iOS #endif // iOS
// arm64-osx // arm64-osx

View File

@ -45,10 +45,59 @@ ErrorCode CPUUnary::onResize(const std::vector<Tensor *> &inputs, const std::vec
static void _Neg(void* out, const void* inp, int realSize) { static void _Neg(void* out, const void* inp, int realSize) {
MNNScaleAndAddBiasScalar((float*)out, (const float*)inp, 0.0f, -1.0f, realSize); MNNScaleAndAddBiasScalar((float*)out, (const float*)inp, 0.0f, -1.0f, realSize);
} }
#ifdef MNN_USE_NEON
static inline void exeNegInt8 (int8_t* out, const int8_t* inp, int sizeQuad, int8x8_t inZeroPoint, int8x8_t outZeroPoint, float32x4_t inpScale, float32x4_t outScale) {
for (int i = 0;i < sizeQuad; ++i) {
int8x16_t negValue = vld1q_s8(inp);
int16x8_t val16_0 = vmovl_s8(vget_low_s8(negValue));
int16x8_t val16_1 = vmovl_s8(vget_high_s8(negValue));
val16_0 = vsubw_s8(val16_0, inZeroPoint);
val16_1 = vsubw_s8(val16_1, inZeroPoint);
int32x4_t val32_00 = vmovl_s16(vget_low_s16(val16_0));
int32x4_t val32_01 = vmovl_s16(vget_high_s16(val16_0));
int32x4_t val32_10 = vmovl_s16(vget_low_s16(val16_1));
int32x4_t val32_11 = vmovl_s16(vget_high_s16(val16_1));
float32x4_t valF_00 = vcvtq_f32_s32(val32_00);
float32x4_t valF_01 = vcvtq_f32_s32(val32_01);
float32x4_t valF_10 = vcvtq_f32_s32(val32_10);
float32x4_t valF_11 = vcvtq_f32_s32(val32_11);
valF_00 = vmulq_f32(valF_00, inpScale);
valF_01 = vmulq_f32(valF_01, inpScale);
valF_10 = vmulq_f32(valF_10, inpScale);
valF_11 = vmulq_f32(valF_11, inpScale);
valF_00 = vnegq_f32(valF_00);
valF_01 = vnegq_f32(valF_01);
valF_10 = vnegq_f32(valF_10);
valF_11 = vnegq_f32(valF_11);
valF_00 = vmulq_f32(valF_00, outScale);
valF_01 = vmulq_f32(valF_01, outScale);
valF_10 = vmulq_f32(valF_10, outScale);
valF_11 = vmulq_f32(valF_11, outScale);
int32x4_t val_00 = vcvtq_s32_f32(valF_00);
int32x4_t val_01 = vcvtq_s32_f32(valF_01);
int32x4_t val_10 = vcvtq_s32_f32(valF_10);
int32x4_t val_11 = vcvtq_s32_f32(valF_11);
int16x4_t v16_0 = vqmovn_s32(val_00);
int16x4_t v16_1 = vqmovn_s32(val_01);
int16x4_t v16_2 = vqmovn_s32(val_10);
int16x4_t v16_3 = vqmovn_s32(val_11);
int16x8_t v16_4 = vcombine_s16(v16_0, v16_1);
int16x8_t v16_5 = vcombine_s16(v16_2, v16_3);
v16_4 = vaddw_s8(v16_4, outZeroPoint);
v16_5 = vaddw_s8(v16_5, outZeroPoint);
int8x8_t v8_0 = vqmovn_s16(v16_4);
int8x8_t v8_1 = vqmovn_s16(v16_5);
vst1_s8(out, v8_0);
vst1_s8(out + 8, v8_1);
inp += 16;
out += 16;
}
}
#endif
static void _NegInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) { static void _NegInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) {
int sizeDiv16 = realSize / 16; int sizeDiv16 = realSize / 16;
int start = 0; int remain = realSize % 16;
#ifdef MNN_USE_NEON #ifdef MNN_USE_NEON
int8_t* outPtr = (int8_t*)out; int8_t* outPtr = (int8_t*)out;
int8_t* inPtr = (int8_t*)inp; int8_t* inPtr = (int8_t*)inp;
@ -57,55 +106,16 @@ static void _NegInt8(void* out, const void* inp, int realSize, QuanPrePostParame
float32x4_t inpScale = vdupq_n_f32(params->inputScale[0]); float32x4_t inpScale = vdupq_n_f32(params->inputScale[0]);
float32x4_t outScale = vdupq_n_f32(params->outputScale[0]); float32x4_t outScale = vdupq_n_f32(params->outputScale[0]);
if (sizeDiv16 > 0) { if (sizeDiv16 > 0) {
for (int i = 0;i < sizeDiv16; ++i) { exeNegInt8(outPtr, inPtr, sizeDiv16, inZeroPoint, outZeroPoint, inpScale, outScale);
int8x16_t negValue = vld1q_s8(inPtr);
int16x8_t val16_0 = vmovl_s8(vget_low_s8(negValue));
int16x8_t val16_1 = vmovl_s8(vget_high_s8(negValue));
val16_0 = vsubw_s8(val16_0, inZeroPoint);
val16_1 = vsubw_s8(val16_1, inZeroPoint);
int32x4_t val32_00 = vmovl_s16(vget_low_s16(val16_0));
int32x4_t val32_01 = vmovl_s16(vget_high_s16(val16_0));
int32x4_t val32_10 = vmovl_s16(vget_low_s16(val16_1));
int32x4_t val32_11 = vmovl_s16(vget_high_s16(val16_1));
float32x4_t valF_00 = vcvtq_f32_s32(val32_00);
float32x4_t valF_01 = vcvtq_f32_s32(val32_01);
float32x4_t valF_10 = vcvtq_f32_s32(val32_10);
float32x4_t valF_11 = vcvtq_f32_s32(val32_11);
valF_00 = vmulq_f32(valF_00, inpScale);
valF_01 = vmulq_f32(valF_01, inpScale);
valF_10 = vmulq_f32(valF_10, inpScale);
valF_11 = vmulq_f32(valF_11, inpScale);
valF_00 = vnegq_f32(valF_00);
valF_01 = vnegq_f32(valF_01);
valF_10 = vnegq_f32(valF_10);
valF_11 = vnegq_f32(valF_11);
valF_00 = vmulq_f32(valF_00, outScale);
valF_01 = vmulq_f32(valF_01, outScale);
valF_10 = vmulq_f32(valF_10, outScale);
valF_11 = vmulq_f32(valF_11, outScale);
int32x4_t val_00 = vcvtq_s32_f32(valF_00);
int32x4_t val_01 = vcvtq_s32_f32(valF_01);
int32x4_t val_10 = vcvtq_s32_f32(valF_10);
int32x4_t val_11 = vcvtq_s32_f32(valF_11);
int16x4_t v16_0 = vqmovn_s32(val_00);
int16x4_t v16_1 = vqmovn_s32(val_01);
int16x4_t v16_2 = vqmovn_s32(val_10);
int16x4_t v16_3 = vqmovn_s32(val_11);
int16x8_t v16_4 = vcombine_s16(v16_0, v16_1);
int16x8_t v16_5 = vcombine_s16(v16_2, v16_3);
v16_4 = vaddw_s8(v16_4, outZeroPoint);
v16_5 = vaddw_s8(v16_5, outZeroPoint);
int8x8_t v8_0 = vqmovn_s16(v16_4);
int8x8_t v8_1 = vqmovn_s16(v16_5);
vst1_s8(outPtr, v8_0);
vst1_s8(outPtr + 8, v8_1);
inPtr += 16;
outPtr += 16;
}
start = 16 * sizeDiv16;
} }
#endif if (remain > 0) {
int8_t intmp[16] = {0};
int8_t outmp[16] = {0};
::memcpy(intmp, reinterpret_cast<const int8_t*>(inp) + 16 * sizeDiv16, remain * sizeof(int8_t));
exeNegInt8(outmp, intmp, 1, inZeroPoint, outZeroPoint, inpScale, outScale);
::memcpy(reinterpret_cast<int8_t*>(out) + 16 * sizeDiv16, outmp, remain * sizeof(int8_t));
}
#else
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
uint8_t* dst = (uint8_t*)out; uint8_t* dst = (uint8_t*)out;
uint8_t* src = (uint8_t*)inp; uint8_t* src = (uint8_t*)inp;
@ -121,7 +131,7 @@ static void _NegInt8(void* out, const void* inp, int realSize, QuanPrePostParame
float outscale_ = params->outputScale[0]; float outscale_ = params->outputScale[0];
int min_ = static_cast<int>(params->minValue); int min_ = static_cast<int>(params->minValue);
int max_ = static_cast<int>(params->maxValue); int max_ = static_cast<int>(params->maxValue);
for (int i = start; i < realSize; ++i) { for (int i = 0; i < realSize; ++i) {
int value = -(src[i] - inzero_ - offset) * inscale_ * outscale_ + outzero_; int value = -(src[i] - inzero_ - offset) * inscale_ * outscale_ + outzero_;
if (value > max_) { if (value > max_) {
value = max_; value = max_;
@ -131,14 +141,65 @@ static void _NegInt8(void* out, const void* inp, int realSize, QuanPrePostParame
} }
dst[i] = value + offset; dst[i] = value + offset;
} }
#endif
} }
static void _ABS(void* out, const void* inp, int realSize) { static void _ABS(void* out, const void* inp, int realSize) {
MNNReluWithSlopeCommon((float*)out, (const float*)inp, realSize, -1.0f); MNNReluWithSlopeCommon((float*)out, (const float*)inp, realSize, -1.0f);
} }
#ifdef MNN_USE_NEON
static inline void exeAbsInt8(int8_t* out, const int8_t* inp, int sizeQuad, int8x8_t inZeroPoint, int8x8_t outZeroPoint, float32x4_t inpScale, float32x4_t outScale) {
for (int i = 0;i < sizeQuad; ++i) {
int8x16_t absValue = vld1q_s8(inp);
int16x8_t val16_0 = vmovl_s8(vget_low_s8(absValue));
int16x8_t val16_1 = vmovl_s8(vget_high_s8(absValue));
val16_0 = vsubw_s8(val16_0, inZeroPoint);
val16_1 = vsubw_s8(val16_1, inZeroPoint);
int32x4_t val32_00 = vmovl_s16(vget_low_s16(val16_0));
int32x4_t val32_01 = vmovl_s16(vget_high_s16(val16_0));
int32x4_t val32_10 = vmovl_s16(vget_low_s16(val16_1));
int32x4_t val32_11 = vmovl_s16(vget_high_s16(val16_1));
float32x4_t valF_00 = vcvtq_f32_s32(val32_00);
float32x4_t valF_01 = vcvtq_f32_s32(val32_01);
float32x4_t valF_10 = vcvtq_f32_s32(val32_10);
float32x4_t valF_11 = vcvtq_f32_s32(val32_11);
valF_00 = vmulq_f32(valF_00, inpScale);
valF_01 = vmulq_f32(valF_01, inpScale);
valF_10 = vmulq_f32(valF_10, inpScale);
valF_11 = vmulq_f32(valF_11, inpScale);
valF_00 = vabsq_f32(valF_00);
valF_01 = vabsq_f32(valF_01);
valF_10 = vabsq_f32(valF_10);
valF_11 = vabsq_f32(valF_11);
valF_00 = vmulq_f32(valF_00, outScale);
valF_01 = vmulq_f32(valF_01, outScale);
valF_10 = vmulq_f32(valF_10, outScale);
valF_11 = vmulq_f32(valF_11, outScale);
int32x4_t val_00 = vcvtq_s32_f32(valF_00);
int32x4_t val_01 = vcvtq_s32_f32(valF_01);
int32x4_t val_10 = vcvtq_s32_f32(valF_10);
int32x4_t val_11 = vcvtq_s32_f32(valF_11);
int16x4_t v16_0 = vqmovn_s32(val_00);
int16x4_t v16_1 = vqmovn_s32(val_01);
int16x4_t v16_2 = vqmovn_s32(val_10);
int16x4_t v16_3 = vqmovn_s32(val_11);
int16x8_t v16_4 = vcombine_s16(v16_0, v16_1);
int16x8_t v16_5 = vcombine_s16(v16_2, v16_3);
v16_4 = vaddw_s8(v16_4, outZeroPoint);
v16_5 = vaddw_s8(v16_5, outZeroPoint);
int8x8_t v8_0 = vqmovn_s16(v16_4);
int8x8_t v8_1 = vqmovn_s16(v16_5);
vst1_s8(out, v8_0);
vst1_s8(out + 8, v8_1);
inp += 16;
out += 16;
}
}
#endif
static void _ABSInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) { static void _ABSInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) {
int sizeDiv16 = realSize / 16; int sizeDiv16 = realSize / 16;
int start = 0; int remain = realSize % 16;
#ifdef MNN_USE_NEON #ifdef MNN_USE_NEON
int8_t* outPtr = (int8_t*)out; int8_t* outPtr = (int8_t*)out;
int8_t* inPtr = (int8_t*)inp; int8_t* inPtr = (int8_t*)inp;
@ -147,55 +208,16 @@ static void _ABSInt8(void* out, const void* inp, int realSize, QuanPrePostParame
float32x4_t inpScale = vdupq_n_f32(params->inputScale[0]); float32x4_t inpScale = vdupq_n_f32(params->inputScale[0]);
float32x4_t outScale = vdupq_n_f32(params->outputScale[0]); float32x4_t outScale = vdupq_n_f32(params->outputScale[0]);
if (sizeDiv16 > 0) { if (sizeDiv16 > 0) {
for (int i = 0;i < sizeDiv16; ++i) { exeAbsInt8(outPtr, inPtr, sizeDiv16, inZeroPoint, outZeroPoint, inpScale, outScale);
int8x16_t absValue = vld1q_s8(inPtr);
int16x8_t val16_0 = vmovl_s8(vget_low_s8(absValue));
int16x8_t val16_1 = vmovl_s8(vget_high_s8(absValue));
val16_0 = vsubw_s8(val16_0, inZeroPoint);
val16_1 = vsubw_s8(val16_1, inZeroPoint);
int32x4_t val32_00 = vmovl_s16(vget_low_s16(val16_0));
int32x4_t val32_01 = vmovl_s16(vget_high_s16(val16_0));
int32x4_t val32_10 = vmovl_s16(vget_low_s16(val16_1));
int32x4_t val32_11 = vmovl_s16(vget_high_s16(val16_1));
float32x4_t valF_00 = vcvtq_f32_s32(val32_00);
float32x4_t valF_01 = vcvtq_f32_s32(val32_01);
float32x4_t valF_10 = vcvtq_f32_s32(val32_10);
float32x4_t valF_11 = vcvtq_f32_s32(val32_11);
valF_00 = vmulq_f32(valF_00, inpScale);
valF_01 = vmulq_f32(valF_01, inpScale);
valF_10 = vmulq_f32(valF_10, inpScale);
valF_11 = vmulq_f32(valF_11, inpScale);
valF_00 = vabsq_f32(valF_00);
valF_01 = vabsq_f32(valF_01);
valF_10 = vabsq_f32(valF_10);
valF_11 = vabsq_f32(valF_11);
valF_00 = vmulq_f32(valF_00, outScale);
valF_01 = vmulq_f32(valF_01, outScale);
valF_10 = vmulq_f32(valF_10, outScale);
valF_11 = vmulq_f32(valF_11, outScale);
int32x4_t val_00 = vcvtq_s32_f32(valF_00);
int32x4_t val_01 = vcvtq_s32_f32(valF_01);
int32x4_t val_10 = vcvtq_s32_f32(valF_10);
int32x4_t val_11 = vcvtq_s32_f32(valF_11);
int16x4_t v16_0 = vqmovn_s32(val_00);
int16x4_t v16_1 = vqmovn_s32(val_01);
int16x4_t v16_2 = vqmovn_s32(val_10);
int16x4_t v16_3 = vqmovn_s32(val_11);
int16x8_t v16_4 = vcombine_s16(v16_0, v16_1);
int16x8_t v16_5 = vcombine_s16(v16_2, v16_3);
v16_4 = vaddw_s8(v16_4, outZeroPoint);
v16_5 = vaddw_s8(v16_5, outZeroPoint);
int8x8_t v8_0 = vqmovn_s16(v16_4);
int8x8_t v8_1 = vqmovn_s16(v16_5);
vst1_s8(outPtr, v8_0);
vst1_s8(outPtr + 8, v8_1);
inPtr += 16;
outPtr += 16;
}
start = 16 * sizeDiv16;
} }
#endif if (remain > 0) {
int8_t intmp[16] = {0};
int8_t outmp[16] = {0};
::memcpy(intmp, reinterpret_cast<const int8_t*>(inp) + 16 * sizeDiv16, remain * sizeof(int8_t));
exeAbsInt8(outmp, intmp, 1, inZeroPoint, outZeroPoint, inpScale, outScale);
::memcpy(reinterpret_cast<int8_t*>(out) + 16 * sizeDiv16, outmp, remain * sizeof(int8_t));
}
#else
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
uint8_t* dst = (uint8_t*)out; uint8_t* dst = (uint8_t*)out;
uint8_t* src = (uint8_t*)inp; uint8_t* src = (uint8_t*)inp;
@ -207,7 +229,7 @@ static void _ABSInt8(void* out, const void* inp, int realSize, QuanPrePostParame
#endif #endif
int inzero_ = static_cast<int>(params->inputZeroPoint[0]); int inzero_ = static_cast<int>(params->inputZeroPoint[0]);
int outzero_ = static_cast<int>(params->outputZeroPoint[0]); int outzero_ = static_cast<int>(params->outputZeroPoint[0]);
for (int i = start; i < realSize; ++i) { for (int i = 0; i < realSize; ++i) {
auto value = abs((src[i] - inzero_ - offset) * params->inputScale[0]); auto value = abs((src[i] - inzero_ - offset) * params->inputScale[0]);
value = value * params->outputScale[0] + outzero_; value = value * params->outputScale[0] + outzero_;
if (value > params->maxValue) { if (value > params->maxValue) {
@ -218,11 +240,38 @@ static void _ABSInt8(void* out, const void* inp, int realSize, QuanPrePostParame
} }
dst[i] = value + offset; dst[i] = value + offset;
} }
#endif
} }
#ifdef MNN_USE_NEON
static inline void exeSignInt8 (int8_t* out, const int8_t* inp, int sizeQuad, int16x8_t one, int16x8_t negone, int16x8_t zero, int8x8_t inZeroPoint, int8x8_t outZeroPoint, float32x4_t outScale) {
for (int i = 0;i < sizeQuad; ++i) {
int8x16_t value = vld1q_s8(inp);
int16x8_t vallow = vmovl_s8(vget_low_s8(value));
int16x8_t valhi = vmovl_s8(vget_high_s8(value));
vallow = vsubw_s8(vallow, inZeroPoint);
valhi = vsubw_s8(valhi, inZeroPoint);
uint16x8_t lomask1 = vcgtq_s16(vallow, zero);
uint16x8_t lomask_1 = vcltq_s16(vallow, zero);
uint16x8_t himask1 = vcgtq_s16(valhi, zero);
uint16x8_t himask_1 = vcltq_s16(valhi, zero);
uint16x8_t zeromask_low = vceqq_u16(lomask1, lomask_1);
uint16x8_t zeromask_hi = vceqq_u16(himask1, himask_1);
vallow = vbslq_s16(lomask1, one, negone);
vallow = vbslq_s16(zeromask_low, zero, vallow);
valhi = vbslq_s16(himask1, one, negone);
valhi = vbslq_s16(zeromask_hi, zero, valhi);
int8x8_t v8_0 = vqmovn_s16(vallow);
int8x8_t v8_1 = vqmovn_s16(valhi);
vst1_s8(out, v8_0);
vst1_s8(out + 8, v8_1);
inp += 16;
out += 16;
}
}
#endif
static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) { static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParameters* params) {
int sizeDiv16 = realSize / 16; int sizeDiv16 = realSize / 16;
int start = 0; int remain = realSize % 16;
#ifdef MNN_USE_NEON #ifdef MNN_USE_NEON
int8_t* outPtr = (int8_t*)out; int8_t* outPtr = (int8_t*)out;
int8_t* inPtr = (int8_t*)inp; int8_t* inPtr = (int8_t*)inp;
@ -233,54 +282,16 @@ static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParam
int8x8_t outZeroPoint = vdup_n_s8(params->outputZeroPoint[0]); int8x8_t outZeroPoint = vdup_n_s8(params->outputZeroPoint[0]);
float32x4_t outScale = vdupq_n_f32(params->outputScale[0]); float32x4_t outScale = vdupq_n_f32(params->outputScale[0]);
if (sizeDiv16 > 0) { if (sizeDiv16 > 0) {
for (int i = 0;i < sizeDiv16; ++i) { exeSignInt8(outPtr, inPtr, sizeDiv16, one, negone, zero, inZeroPoint, outZeroPoint, outScale);
int8x16_t value = vld1q_s8(inPtr);
int16x8_t vallow = vmovl_s8(vget_low_s8(value));
int16x8_t valhi = vmovl_s8(vget_high_s8(value));
vallow = vsubw_s8(vallow, inZeroPoint);
valhi = vsubw_s8(valhi, inZeroPoint);
uint16x8_t lomask1 = vcgtq_s16(vallow, zero);
uint16x8_t lomask_1 = vcltq_s16(vallow, zero);
uint16x8_t himask1 = vcgtq_s16(valhi, zero);
uint16x8_t himask_1 = vcltq_s16(valhi, zero);
vallow = vbslq_s16(lomask1, vallow, one);
vallow = vbslq_s16(lomask_1, vallow, negone);
valhi = vbslq_s16(himask1, valhi, one);
valhi = vbslq_s16(himask_1, valhi, negone);
int32x4_t val32_00 = vmovl_s16(vget_low_s16(vallow));
int32x4_t val32_01 = vmovl_s16(vget_high_s16(vallow));
int32x4_t val32_10 = vmovl_s16(vget_low_s16(valhi));
int32x4_t val32_11 = vmovl_s16(vget_high_s16(valhi));
float32x4_t valF_00 = vcvtq_f32_s32(val32_00);
float32x4_t valF_01 = vcvtq_f32_s32(val32_01);
float32x4_t valF_10 = vcvtq_f32_s32(val32_10);
float32x4_t valF_11 = vcvtq_f32_s32(val32_11);
valF_00 = vmulq_f32(valF_00, outScale);
valF_01 = vmulq_f32(valF_01, outScale);
valF_10 = vmulq_f32(valF_10, outScale);
valF_11 = vmulq_f32(valF_11, outScale);
int32x4_t val_00 = vcvtq_s32_f32(valF_00);
int32x4_t val_01 = vcvtq_s32_f32(valF_01);
int32x4_t val_10 = vcvtq_s32_f32(valF_10);
int32x4_t val_11 = vcvtq_s32_f32(valF_11);
int16x4_t v16_0 = vqmovn_s32(val_00);
int16x4_t v16_1 = vqmovn_s32(val_01);
int16x4_t v16_2 = vqmovn_s32(val_10);
int16x4_t v16_3 = vqmovn_s32(val_11);
int16x8_t v16_4 = vcombine_s16(v16_0, v16_1);
int16x8_t v16_5 = vcombine_s16(v16_2, v16_3);
v16_4 = vaddw_s8(v16_4, outZeroPoint);
v16_5 = vaddw_s8(v16_5, outZeroPoint);
int8x8_t v8_0 = vqmovn_s16(v16_4);
int8x8_t v8_1 = vqmovn_s16(v16_5);
vst1_s8(outPtr, v8_0);
vst1_s8(outPtr + 8, v8_1);
inPtr += 16;
outPtr += 16;
}
start = 16 * sizeDiv16;
} }
#endif if (remain > 0) {
int8_t intmp[16] = {0};
int8_t outmp[16] = {0};
::memcpy(intmp, reinterpret_cast<const int8_t*>(inp) + 16 * sizeDiv16, remain * sizeof(int8_t));
exeSignInt8(outmp, intmp, 1, one, negone, zero, inZeroPoint, outZeroPoint, outScale);
::memcpy(reinterpret_cast<int8_t*>(out) + 16 * sizeDiv16, outmp, remain * sizeof(int8_t));
}
#else
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
uint8_t* dst = (uint8_t*)out; uint8_t* dst = (uint8_t*)out;
uint8_t* src = (uint8_t*)inp; uint8_t* src = (uint8_t*)inp;
@ -292,7 +303,7 @@ static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParam
#endif #endif
int inzero_ = static_cast<int>(params->inputZeroPoint[0]); int inzero_ = static_cast<int>(params->inputZeroPoint[0]);
int outzero_ = static_cast<int>(params->outputZeroPoint[0]); int outzero_ = static_cast<int>(params->outputZeroPoint[0]);
for (int i = start; i < realSize; ++i) { for (int i = 0; i < realSize; ++i) {
auto value = src[i] - offset - inzero_; auto value = src[i] - offset - inzero_;
if (value > 0) { if (value > 0) {
int f = 1 * params->outputScale[0] + outzero_; int f = 1 * params->outputScale[0] + outzero_;
@ -304,6 +315,7 @@ static void _SignInt8(void* out, const void* inp, int realSize, QuanPrePostParam
dst[i] = outzero_ + offset; dst[i] = outzero_ + offset;
} }
} }
#endif
} }
static void _Square(void* out, const void* inp, int realSize) { static void _Square(void* out, const void* inp, int realSize) {
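The _NegInt8/_ABSInt8/_SignInt8 rewrites above share one pattern: the 16-lane NEON helper handles all full quads, and the tail is handled by copying the last `remain` elements into a zero-padded 16-byte scratch buffer, running the same helper once more, and copying only `remain` results back. A portable sketch of that tail-padding idea (plain C++, no intrinsics, illustrative kernel):

    #include <cstdint>
    #include <cstring>

    // Stand-in for a 16-wide SIMD kernel: negate one block of 16 int8 values.
    static void negBlock16(int8_t* out, const int8_t* in) {
        for (int i = 0; i < 16; ++i) out[i] = static_cast<int8_t>(-in[i]);
    }

    static void negAll(int8_t* out, const int8_t* in, int n) {
        int blocks = n / 16;
        int remain = n % 16;
        for (int b = 0; b < blocks; ++b) {
            negBlock16(out + 16 * b, in + 16 * b);
        }
        if (remain > 0) {
            int8_t intmp[16] = {0}, outmp[16] = {0};
            ::memcpy(intmp, in + 16 * blocks, remain);   // pad the tail with zeros
            negBlock16(outmp, intmp);                    // one extra full block
            ::memcpy(out + 16 * blocks, outmp, remain);  // write back only `remain`
        }
    }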

View File

@ -45,8 +45,8 @@ vdup.32 q10, r8 //q10: [28.f]x4
vdup.32 q9, r10 //q9: [3150.f]x4 vdup.32 q9, r10 //q9: [3150.f]x4
vdup.32 q8, r11 //q8: [62370.f]x4 vdup.32 q8, r11 //q8: [62370.f]x4
mov r4, #5.0 mov r4, #5
mov r5, #-5.0 mov r5, #-5
GeluZLoop: GeluZLoop:
@ -68,6 +68,8 @@ vmul.f32 q3, q3, q14 // value
// if value > 5, then value=5; if value<-5, then value=-5 // if value > 5, then value=5; if value<-5, then value=-5
vdup.32 q7, r4 vdup.32 q7, r4
vdup.32 q6, r5 vdup.32 q6, r5
vcvt.f32.s32 q7, q7
vcvt.f32.s32 q6, q6
vmax.f32 q2, q2, q6 vmax.f32 q2, q2, q6
vmax.f32 q3, q3, q6 vmax.f32 q3, q3, q6
vmin.f32 q2, q2, q7 vmin.f32 q2, q2, q7

View File

@ -45,8 +45,8 @@ vdup.32 q10, r8 //q10: [28.f]x4
vdup.32 q9, r10 //q9: [3150.f]x4 vdup.32 q9, r10 //q9: [3150.f]x4
vdup.32 q8, r11 //q8: [62370.f]x4 vdup.32 q8, r11 //q8: [62370.f]x4
mov r4, #5.0 mov r4, #5
mov r5, #-5.0 mov r5, #-5
GeluZLoop: GeluZLoop:
@ -70,6 +70,8 @@ vmul.f32 q3, q3, q14
vdup.32 q7, r4 vdup.32 q7, r4
vdup.32 q6, r5 vdup.32 q6, r5
vcvt.f32.s32 q7, q7
vcvt.f32.s32 q6, q6
vmax.f32 q2, q2, q6 vmax.f32 q2, q2, q6
vmax.f32 q3, q3, q6 vmax.f32 q3, q3, q6
vmin.f32 q2, q2, q7 vmin.f32 q2, q2, q7
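Both assembly fixes above address the same bug: `mov r4, #5.0` is not a valid ARM immediate (mov only accepts integers), so the clamp bounds are now loaded as integer 5/-5 and converted to float with vcvt.f32.s32 before the vmax/vmin clamp. In scalar terms the kernel clamps the tanh argument to [-5, 5] before the rational tanh approximation built from the 135135/17325/378/62370/3150/28 constants loaded above; an illustrative C++ equivalent:

    #include <algorithm>

    // Clamp to +/-5, then the Pade-style tanh approximation used by the GELU kernel.
    inline float tanhClamped(float value) {
        value = std::max(-5.0f, std::min(5.0f, value));
        float x2 = value * value;
        float a  = value * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
        float b  = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
        return a / b;
    }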

View File

@ -2595,35 +2595,20 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i
} }
void MNNExp(float* dst, const float* src, const float* offset, size_t dataSize) { void MNNExp(float* dst, const float* src, const float* offset, size_t dataSize) {
int countC8 = (int)dataSize / 8; int countC8 = static_cast<int32_t>(dataSize) / 8;
int remain = static_cast<int32_t>(dataSize) % 8;
float parameters[] = {
(float)logf(2.0f), 1.0f / (float)logf(2.0f), 1.0f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
if (countC8 > 0) { if (countC8 > 0) {
// Align to eight so asm is easier to write // Align to eight so asm is easier to write
float parameters[] = {
(float)logf(2.0f), 1.0f / (float)logf(2.0f), 1.0f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
MNNExpC8(dst, src, offset, parameters, countC8); MNNExpC8(dst, src, offset, parameters, countC8);
} }
float alpha = offset[0]; if (remain > 0) {
float beta = offset[1]; float intmp[8] = {0};
int remain = countC8 * 8; float outmp[8] = {0};
auto param = logf(2.0f); ::memcpy(intmp, src + 8 * countC8, remain * sizeof(float));
float xLimit = 87; MNNExpC8(outmp, intmp, offset, parameters, 1);
for (int i = remain; i < dataSize; i++) { ::memcpy(dst + 8 * countC8, outmp, remain * sizeof(float));
/*Origin Function*/
//dst[i] = expf(src[i] * alpha) + beta;
/*Approximate Function*/
auto x = alpha * src[i];
x = ALIMAX(x, -xLimit);
x = ALIMIN(x, xLimit);
int div = (x / param);
int div2 = (div + 127) << 23;
auto xReamin = x - div * param;
float expBasic = *(float*)(&div2);
auto t = xReamin;
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
dst[i] = expBasic * expRemain + beta;
} }
} }
@ -2670,30 +2655,33 @@ void MNNReluWithSlope(float* dst, const float* src, size_t sizeQuad, float slope
} }
void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slope) { void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slope) {
int sizeQuad = size / 4; int sizeQuad = static_cast<int32_t>(size) / 4;
int start = 0; int remain = static_cast<int32_t>(size) % 4;
if (sizeQuad > 0) { if (sizeQuad > 0) {
MNNReluWithSlope(dst, src, sizeQuad, slope); MNNReluWithSlope(dst, src, sizeQuad, slope);
start = sizeQuad * 4;
} }
for (int j = start; j < size; j++) { if (remain > 0) {
if (src[j] < 0) { float intmp[4] = {0}, outmp[4] = {0};
dst[j] = src[j] * slope; ::memcpy(intmp, src + sizeQuad * 4, remain * sizeof(float));
} else { MNNReluWithSlope(outmp, intmp, 1, slope);
dst[j] = src[j]; ::memcpy(dst + sizeQuad * 4, outmp, remain * sizeof(float));
}
} }
} }
void MNNHardSwishCommon(float* dst, const float* src, size_t size) { void MNNHardSwishCommon(float* dst, const float* src, size_t size) {
int sizeQuad = static_cast<int32_t>(size / 4); int sizeQuad = static_cast<int32_t>(size / 4);
int start = 0; int remain = static_cast<int32_t>(size) % 4;
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
if (sizeQuad > 0) { if (sizeQuad > 0) {
MNNHardSwish(dst, src, sizeQuad); MNNHardSwish(dst, src, sizeQuad);
start = sizeQuad * 4;
} }
#endif if (remain > 0) {
float intmp[4] = {0}, outmp[4] = {0};
::memcpy(intmp, src + sizeQuad * 4, remain * sizeof(float));
MNNHardSwish(outmp, intmp, 1);
::memcpy(dst + sizeQuad * 4, outmp, remain * sizeof(float));
}
#else
#ifdef MNN_USE_NEON #ifdef MNN_USE_NEON
float32x4_t zero = vdupq_n_f32(0.f); float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t three = vdupq_n_f32(3.f); float32x4_t three = vdupq_n_f32(3.f);
@ -2704,9 +2692,16 @@ void MNNHardSwishCommon(float* dst, const float* src, size_t size) {
auto y = vmulq_f32(vmulq_f32(x, vminq_f32(vmaxq_f32(vaddq_f32(x, three), zero), six)), divsix); auto y = vmulq_f32(vmulq_f32(x, vminq_f32(vmaxq_f32(vaddq_f32(x, three), zero), six)), divsix);
vst1q_f32(dst + 4 * i, y); vst1q_f32(dst + 4 * i, y);
} }
start = sizeQuad * 4; if (remain > 0) {
#endif float intmp[4] = {0}, outmp[4] = {0};
for (int j = start; j < size; j++) { ::memcpy(intmp, src + sizeQuad * 4, remain * sizeof(float));
auto x = vld1q_f32(intmp);
auto y = vmulq_f32(vmulq_f32(x, vminq_f32(vmaxq_f32(vaddq_f32(x, three), zero), six)), divsix);
vst1q_f32(outmp, y);
::memcpy(dst + sizeQuad * 4, outmp, remain * sizeof(float));
}
#else
for (int j = 0; j < size; j++) {
if (src[j] <= -3) { if (src[j] <= -3) {
dst[j] = 0; dst[j] = 0;
} else if (src[j] >= 3){ } else if (src[j] >= 3){
@ -2715,6 +2710,8 @@ void MNNHardSwishCommon(float* dst, const float* src, size_t size) {
dst[j] = src[j] * (src[j] + 3) / 6.f; dst[j] = src[j] * (src[j] + 3) / 6.f;
} }
} }
#endif
#endif
} }
void MNNGeluStandardCommon(float* dst, const float* src, size_t size) { void MNNGeluStandardCommon(float* dst, const float* src, size_t size) {
@ -2725,14 +2722,20 @@ void MNNGeluStandardCommon(float* dst, const float* src, size_t size) {
void MNNGeluCommon(float* dst, const float* src, size_t size) { void MNNGeluCommon(float* dst, const float* src, size_t size) {
int sizeQuad = static_cast<int32_t>(size / 8); int sizeQuad = static_cast<int32_t>(size / 8);
int start = 0; int remain = static_cast<int32_t>(size) % 8;
#if defined(MNN_USE_SSE) || defined(MNN_USE_NEON) #if defined(MNN_USE_SSE) || defined(MNN_USE_NEON)
float parameters[8] = {0.044715f, 0.79788458f, 378.f, 17325.f, 135135.f, 28.f, 3150.f, 62370.f};
if (sizeQuad > 0) { if (sizeQuad > 0) {
float parameters[8] = {0.044715f, 0.79788458f, 378.f, 17325.f, 135135.f, 28.f, 3150.f, 62370.f};
MNNGelu(dst, src, sizeQuad, parameters); MNNGelu(dst, src, sizeQuad, parameters);
start = sizeQuad * 8;
} }
#endif if (remain > 0) {
float intmp[8] = {0};
float outmp[8] = {0};
::memcpy(intmp, src + 8 * sizeQuad, remain * sizeof(float));
MNNGelu(outmp, intmp, 1, parameters);
::memcpy(dst + 8 * sizeQuad, outmp, remain * sizeof(float));
}
#else
auto tanhf_poly = [](float value) -> float { auto tanhf_poly = [](float value) -> float {
if (value > 5.0f) { if (value > 5.0f) {
return 1.0f; return 1.0f;
@ -2745,11 +2748,12 @@ void MNNGeluCommon(float* dst, const float* src, size_t size) {
return a / b; return a / b;
} }
}; };
for (int i = start; i < size; i++) { for (int i = 0; i < size; i++) {
float temp = 0.044715f * src[i] * src[i] * src[i]; float temp = 0.044715f * src[i] * src[i] * src[i];
temp = 0.79788458f * (temp + src[i]); temp = 0.79788458f * (temp + src[i]);
dst[i] = (1.0f + tanhf_poly(temp)) * src[i] * 0.5f; dst[i] = (1.0f + tanhf_poly(temp)) * src[i] * 0.5f;
} }
#endif
} }
void MNNScaleAndAddBiasScalar(float* dst, const float* src, float bias, float alpha, size_t number) { void MNNScaleAndAddBiasScalar(float* dst, const float* src, float bias, float alpha, size_t number) {
@ -3056,11 +3060,13 @@ void MNNSigmoidLowp(float* dst, const float* src, size_t dataSize) {
}; };
MNNExp(dst, src, offset, dataSize); MNNExp(dst, src, offset, dataSize);
#ifdef MNN_USE_NEON #ifdef MNN_USE_NEON
int dataC4 = (int)dataSize / 4; int dataC4 = static_cast<int32_t>(dataSize) / 4;
int remain = static_cast<int32_t>(dataSize) % 4;
float32x4_t value = vdupq_n_f32(1.0f);
if(dataC4 > 0) { if(dataC4 > 0) {
// neon optimization for sigmid cpu
float32x4_t value = vdupq_n_f32(1.0f);
float32x4_t out = vld1q_f32(dst); float32x4_t out = vld1q_f32(dst);
// neon optimization for sigmoid cpu
for (int i = 1; i < dataC4; ++i) { for (int i = 1; i < dataC4; ++i) {
out = vrecpeq_f32(vaddq_f32(value,out)); out = vrecpeq_f32(vaddq_f32(value,out));
vst1q_f32(dst ,out); vst1q_f32(dst ,out);
@ -3070,12 +3076,20 @@ void MNNSigmoidLowp(float* dst, const float* src, size_t dataSize) {
out = vrecpeq_f32(vaddq_f32(value,out)); out = vrecpeq_f32(vaddq_f32(value,out));
vst1q_f32(dst, out); vst1q_f32(dst, out);
dst += 4; dst += 4;
dataSize = dataSize - 4 * dataC4;
} }
#endif if (remain > 0) {
float intmp[4] = {0};
::memcpy(intmp, dst, remain * sizeof(float));
float32x4_t out = vld1q_f32(intmp);
out = vrecpeq_f32(vaddq_f32(value,out));
vst1q_f32(intmp, out);
::memcpy(dst, intmp, remain * sizeof(float));
}
#else
for (int i = 0; i < dataSize; ++i) { for (int i = 0; i < dataSize; ++i) {
dst[i] = 1.0f / (1.0f + dst[i]); dst[i] = 1.0f / (1.0f + dst[i]);
} }
#endif
} }
void MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameters) { void MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameters) {
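The scalar fallback removed from MNNExp above (its tail now goes through MNNExpC8 on a padded 8-float buffer) computed exp(x) by splitting x = k*ln2 + r, building 2^k directly from the float exponent bits ((k + 127) << 23) and multiplying by a fifth-order polynomial in r, with x clamped to +/-87 to stay in float range. A standalone sketch of that scheme, mirroring the removed code's structure:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // exp(x) ~= 2^k * poly(r), where x = k*ln2 + r and 2^k is assembled from exponent bits.
    inline float fastExp(float x) {
        const float ln2 = logf(2.0f);
        x = std::max(-87.f, std::min(87.f, x));
        int k = static_cast<int>(x / ln2);                     // truncate, as the removed code did
        float r = x - k * ln2;
        uint32_t bits = static_cast<uint32_t>(k + 127) << 23;  // IEEE-754 exponent of 2^k
        float pow2k;
        ::memcpy(&pow2k, &bits, sizeof(pow2k));
        float poly = ((((1.f / 120 * r + 1.f / 24) * r + 1.f / 6) * r + 0.5f) * r + 1.f) * r + 1.f;
        return pow2k * poly;
    }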

View File

@ -231,6 +231,10 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
} else { } else {
quanParam.minValue = mMutableResource.mClampMin; quanParam.minValue = mMutableResource.mClampMin;
} }
int dstBytes = static_cast<CPUBackend*>(backend())->getBytes(backend(), output);
if (dstBytes != 1) {
quanParam.useInt8 = 0;
}
//MNN_PRINT("max: %d, min: %d\n", quanParam.maxValue, quanParam.minValue); //MNN_PRINT("max: %d, min: %d\n", quanParam.maxValue, quanParam.minValue);
const int col_buffer_unit_size = mIm2ColParamter.kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); const int col_buffer_unit_size = mIm2ColParamter.kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t);
auto col_buffer_size = col_buffer_unit_size * mIm2ColCount; auto col_buffer_size = col_buffer_unit_size * mIm2ColCount;
@ -262,13 +266,13 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
if (number > 0) { if (number > 0) {
blitProc(colAddr, srcPtr, info, el); blitProc(colAddr, srcPtr, info, el);
} }
auto outputInTilePtr = outputDataPtr + xIndexStart * PackUnit; auto outputInTilePtr = outputDataPtr + xIndexStart * PackUnit * dstBytes;
auto colAddrTemp = colAddr; auto colAddrTemp = colAddr;
do { do {
int step = ALIMIN(DST_XUNIT, realDstCount); int step = ALIMIN(DST_XUNIT, realDstCount);
mGemmKernel(outputInTilePtr, colAddrTemp, weightDataPtr, kernelCountUnitDouble, dstZStep, ocDiv4, &quanParam, step); mGemmKernel(outputInTilePtr, colAddrTemp, weightDataPtr, kernelCountUnitDouble, dstZStep * dstBytes, ocDiv4, &quanParam, step);
realDstCount-=step; realDstCount-=step;
outputInTilePtr += DST_XUNIT * PackUnit; outputInTilePtr += DST_XUNIT * PackUnit * dstBytes;
colAddrTemp += col_buffer_unit_size; colAddrTemp += col_buffer_unit_size;
} while(realDstCount > 0); } while(realDstCount > 0);
} }
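The dstBytes handling above lets the int8 GEMM write wider outputs: when the destination element is larger than one byte (fp16/fp32), quanParam.useInt8 is cleared and every output pointer and stride is scaled by dstBytes, so the tile loop advances in bytes rather than int8 elements. A trivial sketch of that byte-addressed stepping (illustrative names):

    #include <cstddef>
    #include <cstdint>

    // Advance an output tile pointer by elements * element-size, so the same loop
    // serves int8 (1-byte) and fp32 (4-byte) destinations.
    inline uint8_t* tilePtr(uint8_t* base, int tileStart, int packUnit, int dstBytes) {
        return base + static_cast<size_t>(tileStart) * packUnit * dstBytes;
    }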

View File

@ -110,9 +110,6 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector<Tensor *> &inputs,
auto matrixSizeE = output->height() * output->width() * input->batch(); auto matrixSizeE = output->height() * output->width() * input->batch();
auto outputPlane = output->height() * output->width(); auto outputPlane = output->height() * output->width();
mUnits.clear(); mUnits.clear();
auto inputPtr = TensorUtils::getDescribe(input)->mem->chunk();
auto outputPtr = TensorUtils::getDescribe(output)->mem->chunk();
std::shared_ptr<char> __autoFunction; std::shared_ptr<char> __autoFunction;
auto padY = mPadY; auto padY = mPadY;
auto padX = mPadX; auto padX = mPadX;
@ -156,9 +153,9 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector<Tensor *> &inputs,
int e = planeSize; int e = planeSize;
int l = ic; int l = ic;
int h = oc; int h = oc;
auto aPtr = inputPtr + core->pack * planeStart * bytes; uint8_t* aPtr = nullptr;
auto bPtr = TensorUtils::getDescribe(weightTensor)->mem->chunk();; auto bPtr = TensorUtils::getDescribe(weightTensor)->mem->chunk();;
auto cPtr = outputPtr + core->pack * planeStart * bytes; uint8_t* cPtr = nullptr;
auto biasPtr = TensorUtils::getDescribe(mResource->mBias.get())->mem->chunk(); auto biasPtr = TensorUtils::getDescribe(mResource->mBias.get())->mem->chunk();
memoryPool->beginGroup(); memoryPool->beginGroup();
auto code = unit.mStracssenComputor->onEncode(e, l, h, matrixSizeE * core->pack, UP_DIV(l, lPack) * lPack * hPack, matrixSizeE * core->pack, aPtr, bPtr, cPtr, true, biasPtr, postParameters); auto code = unit.mStracssenComputor->onEncode(e, l, h, matrixSizeE * core->pack, UP_DIV(l, lPack) * lPack * hPack, matrixSizeE * core->pack, aPtr, bPtr, cPtr, true, biasPtr, postParameters);
@ -200,9 +197,9 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector<Tensor *> &inputs,
int e = matrixSizeE; int e = matrixSizeE;
int l = ic; int l = ic;
int h = std::min(ocSize * core->pack, ocWeightSize * hPack); int h = std::min(ocSize * core->pack, ocWeightSize * hPack);
auto aPtr = inputPtr; uint8_t* aPtr = nullptr;
auto bPtr = TensorUtils::getDescribe(mResource->mWeight.get())->mem->chunk() + hPack * icAlign * ocStartWeight * bytes; auto bPtr = TensorUtils::getDescribe(mResource->mWeight.get())->mem->chunk() + hPack * icAlign * ocStartWeight * bytes;
auto cPtr = outputPtr + core->pack * matrixSizeE * ocStart * bytes; uint8_t* cPtr = nullptr;
auto biasPtr = TensorUtils::getDescribe(mResource->mBias.get())->mem->chunk() + core->pack * ocStart * bytes; auto biasPtr = TensorUtils::getDescribe(mResource->mBias.get())->mem->chunk() + core->pack * ocStart * bytes;
memoryPool->beginGroup(); memoryPool->beginGroup();
auto code = unit.mStracssenComputor->onEncode(e, l, h, matrixSizeE * core->pack, UP_DIV(l, lPack) * lPack * hPack, matrixSizeE * core->pack, aPtr, bPtr, cPtr, true, biasPtr, postParameters); auto code = unit.mStracssenComputor->onEncode(e, l, h, matrixSizeE * core->pack, UP_DIV(l, lPack) * lPack * hPack, matrixSizeE * core->pack, aPtr, bPtr, cPtr, true, biasPtr, postParameters);

View File

@ -1453,7 +1453,9 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co
for (int j = 0; j < GEMM_INT8_UNIT; ++j) { for (int j = 0; j < GEMM_INT8_UNIT; ++j) {
const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT; const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT;
for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) { for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) {
// if (j == 2) printf("%d, %d\n", (int32_t)src_z[i], (int32_t)weight_j[i]);
dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i]; dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i];
// if (j == 0) printf("%d\n", dstTemp[j]);
} }
} }
} }

View File

@ -6,7 +6,6 @@ if(MNN_CUDA_PROFILE)
set(EXTRA_LIBS -lnvToolsExt) set(EXTRA_LIBS -lnvToolsExt)
endif() endif()
if(CUDA_FOUND) if(CUDA_FOUND)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -D_FORCE_INLINES -Wno-deprecated-gpu-targets -w ${EXTRA_LIBS}") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -D_FORCE_INLINES -Wno-deprecated-gpu-targets -w ${EXTRA_LIBS}")
if(CMAKE_BUILD_TYPE MATCHES Debug) if(CMAKE_BUILD_TYPE MATCHES Debug)
@ -52,6 +51,7 @@ if(CUDA_FOUND)
ENDIF() ENDIF()
# Limit minimum cuda version for each archs # Limit minimum cuda version for each archs
IF (${arch_count} EQUAL 1) IF (${arch_count} EQUAL 1)
IF ((CUDA_ARCH_FLAGS_readable_code VERSION_GREATER "80") OR (CUDA_ARCH_FLAGS_readable_code VERSION_EQUAL "80")) IF ((CUDA_ARCH_FLAGS_readable_code VERSION_GREATER "80") OR (CUDA_ARCH_FLAGS_readable_code VERSION_EQUAL "80"))
IF (CUDA_VERSION VERSION_LESS "11.2") IF (CUDA_VERSION VERSION_LESS "11.2")

View File

@ -52,7 +52,8 @@ ErrorCode BinaryExecution::onExecute(const std::vector<Tensor *> &inputs, const
int stride1[3] = {0, 0, s1}; int stride1[3] = {0, 0, s1};
int stride2[3] = {0, 0, 1}; int stride2[3] = {0, 0, 1};
auto type = outputs[0]->getType(); // use the input type; the output type may be fixed (e.g. Greater/Less outputs int)
auto type = inputs[0]->getType();
if (type.code == halide_type_float) { if (type.code == halide_type_float) {
// Use Half or float // Use Half or float
type.bits = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]) * 8; type.bits = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]) * 8;

View File

@ -199,54 +199,55 @@ ErrorCode ConvCutlassExecution::onResize(const std::vector<Tensor*> &inputs, con
mGpuComputeCap = runtime->compute_capability(); mGpuComputeCap = runtime->compute_capability();
//MNN_PRINT("Gpu smArch is sm_%d\n", mGpuComputeCap); //MNN_PRINT("Gpu smArch is sm_%d\n", mGpuComputeCap);
if(mGpuComputeCap < 70) { if (mGpuComputeCap < 70) {
return callCutlassGemmCudaCoreFloat16(inputs, outputs); return callCutlassGemmCudaCoreFloat16(inputs, outputs);
} else if(mGpuComputeCap < 75) { } else if (mGpuComputeCap < 75) {
return callCutlassGemmTensorCore884(inputs, outputs); return callCutlassGemmTensorCore884(inputs, outputs);
} }
#ifdef ENABLE_CUDA_TUNE_PARAM
if (mGpuComputeCap >= 80) {
mIsTuned = true;
/*
// 0 -> Gemm, 1~N -> BatchGemm
int32_t batchSize = 0;
// [0]->A, [1]->B, [2]->bias, [3]->output
std::pair<void *, int32_t> ptrOffset[4];
int32_t batchOffset[4];
// [0]->alpha, [1]->beta, [2]->splitK
int32_t coefs[3];
// 0 -> RowColumn, 1 -> RowRow
int32_t layout;
bool epilogueVectorize
*/
mInfo.problemSize[0] = mGemmInfo.elh[0];
mInfo.problemSize[1] = mGemmInfo.elhPad[2];
mInfo.problemSize[2] = mGemmInfo.elhPad[1];
#ifdef ENABLE_CUDA_TUNE_PARAM mInfo.coefs[0] = 1;
/* mInfo.coefs[1] = 1;
// 0 -> Gemm, 1~N -> BatchGemm mInfo.coefs[2] = 1;
int32_t batchSize = 0;
// [0]->A, [1]->B, [2]->bias, [3]->output
std::pair<void *, int32_t> ptrOffset[4];
int32_t batchOffset[4];
// [0]->alpha, [1]->beta, [2]->splitK
int32_t coefs[3];
// 0 -> RowColumn, 1 -> RowRow
int32_t layout;
bool epilogueVectorize
*/
mInfo.problemSize[0] = mGemmInfo.elh[0];
mInfo.problemSize[1] = mGemmInfo.elhPad[2];
mInfo.problemSize[2] = mGemmInfo.elhPad[1];
mInfo.coefs[0] = 1; mInfo.epilogueVectorize = true;
mInfo.coefs[1] = 1; mInfo.epilogueType = mActivationType;// Linear-Relu-Relu6
mInfo.coefs[2] = 1; mInfo.precisionType = mPrecisonLevel;//
mInfo.backend = mBackendPtr;
mInfo.epilogueVectorize = true; mInfo.batchSize = 0;// For Gemm
mInfo.epilogueType = mActivationType;// Linear-Relu-Relu6 mInfo.layout = 0;
mInfo.precisionType = mPrecisonLevel;// void *inputA_ptr = mNeedIm2Col ? (void *)mIm2ColBuffer : (void *)input->deviceId();
mInfo.backend = mBackendPtr;
mInfo.batchSize = 0;// For Gemm mInfo.ptrOffset[0] = std::make_pair((void *)inputA_ptr, mGemmInfo.elhPad[1]);
mInfo.layout = 0; mInfo.ptrOffset[1] = std::make_pair((void *)mFilterAddr, mGemmInfo.elhPad[1]);
void *inputA_ptr = mNeedIm2Col ? (void *)mIm2ColBuffer : (void *)input->deviceId(); mInfo.ptrOffset[2] = std::make_pair((void *)mBiasAddr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)outputs[0]->deviceId(), mGemmInfo.elhPad[2]);
mInfo.ptrOffset[0] = std::make_pair((void *)inputA_ptr, mGemmInfo.elhPad[1]); getGemmTensorCoreFloat16Param(&mInfo);
mInfo.ptrOffset[1] = std::make_pair((void *)mFilterAddr, mGemmInfo.elhPad[1]); // set preferred block shape arguments
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasAddr, 0); setGemmTensorCoreFloat16Argments(&mInfo);
mInfo.ptrOffset[3] = std::make_pair((void *)outputs[0]->deviceId(), mGemmInfo.elhPad[2]); return NO_ERROR;
getGemmTensorCoreFloat16Param(&mInfo); }
// set preferd block shape argments #endif
setGemmTensorCoreFloat16Argments(&mInfo);
return NO_ERROR;
#else
return callCutlassGemmTensorCore(inputs, outputs); return callCutlassGemmTensorCore(inputs, outputs);
#endif
} }
ErrorCode ConvCutlassExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) { ErrorCode ConvCutlassExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
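With ENABLE_CUDA_TUNE_PARAM defined, the tuned CUTLASS path above is now taken only on SM80+ devices (mIsTuned = true); older GPUs keep the pre-built SM75 kernel setup, and onExecute dispatches on the same flag. A minimal sketch of that guard, with the real kernel calls stubbed out:

    // Guard pattern only; the MNN classes and CUTLASS setup are elided.
    struct ConvLikeExecution {
        bool mIsTuned = false;

        void onResize(int computeCapability) {
            if (computeCapability >= 80) {
                mIsTuned = true;
                // fill tuning info, query preferred block shapes ...
                return;
            }
            // initialize the statically chosen SM75 kernel arguments ...
        }

        void onExecute() {
            if (mIsTuned) {
                // run the tuned kernel
            } else {
                // run the pre-built SM75 kernel
            }
        }
    };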

View File

@ -310,75 +310,78 @@ ErrorCode ConvWinogradExecution::onResize(const std::vector<Tensor*> &inputs, c
//MNN_PRINT("Winograd BatchGemm batch:%d, MNK:%d-%d-%d\n", mBlock2, mGemmInfo.elh[0], mGemmInfo.elhPad[2], mGemmInfo.elhPad[1]); //MNN_PRINT("Winograd BatchGemm batch:%d, MNK:%d-%d-%d\n", mBlock2, mGemmInfo.elh[0], mGemmInfo.elhPad[2], mGemmInfo.elhPad[1]);
if(mFp16Infer) { if(mFp16Infer) {
#ifdef ENABLE_CUDA_TUNE_PARAM #ifdef ENABLE_CUDA_TUNE_PARAM
/* if(mGpuComputeCap >= 80 ) {
// 0 -> Gemm, 1~N -> BatchGemm mIsTuned = true;
int32_t batchSize = 0; /*
// [0]->A, [1]->B, [2]->bias, [3]->output // 0 -> Gemm, 1~N -> BatchGemm
std::pair<void *, int32_t> ptrOffset[4]; int32_t batchSize = 0;
int32_t batchOffset[4]; // [0]->A, [1]->B, [2]->bias, [3]->output
// [0]->alpha, [1]->beta, [2]->splitK std::pair<void *, int32_t> ptrOffset[4];
int32_t coefs[3]; int32_t batchOffset[4];
// 0 -> RowColumn, 1 -> RowRow // [0]->alpha, [1]->beta, [2]->splitK
int32_t layout; int32_t coefs[3];
bool epilogueVectorize // 0 -> RowColumn, 1 -> RowRow
*/ int32_t layout;
mInfo.problemSize[0] = mGemmInfo.elh[0]; bool epilogueVectorize
mInfo.problemSize[1] = mGemmInfo.elhPad[2]; */
mInfo.problemSize[2] = mGemmInfo.elhPad[1]; mInfo.problemSize[0] = mGemmInfo.elh[0];
mInfo.problemSize[1] = mGemmInfo.elhPad[2];
mInfo.problemSize[2] = mGemmInfo.elhPad[1];
mInfo.coefs[0] = 1; mInfo.coefs[0] = 1;
mInfo.coefs[1] = 0; mInfo.coefs[1] = 0;
mInfo.epilogueVectorize = true; mInfo.epilogueVectorize = true;
mInfo.epilogueType = 0;// Linear mInfo.epilogueType = 0;// Linear
mInfo.precisionType = 2;// FP16_FP16 mInfo.precisionType = 2;// FP16_FP16
mInfo.backend = mResource->mBackend; mInfo.backend = mResource->mBackend;
mInfo.batchSize = mBlock2; mInfo.batchSize = mBlock2;
mInfo.layout = 0; mInfo.layout = 0;
mInfo.ptrOffset[0] = std::make_pair((void *)mBtdB_Buffer, mGemmInfo.elhPad[1]); mInfo.ptrOffset[0] = std::make_pair((void *)mBtdB_Buffer, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mResource->mFilter, mGemmInfo.elhPad[1]); mInfo.ptrOffset[1] = std::make_pair((void *)mResource->mFilter, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[2] = std::make_pair((void *)mResource->mBias, 0); mInfo.ptrOffset[2] = std::make_pair((void *)mResource->mBias, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)mMatmul_Buffer, mGemmInfo.elhPad[2]); mInfo.ptrOffset[3] = std::make_pair((void *)mMatmul_Buffer, mGemmInfo.elhPad[2]);
mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1]; mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1];
mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]; mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2];
mInfo.batchOffset[2] = 0; mInfo.batchOffset[2] = 0;
mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elhPad[2]; mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elhPad[2];
getGemmBatchedTensorCoreFloat16Param(&mInfo); getGemmBatchedTensorCoreFloat16Param(&mInfo);
// set preferred block shape arguments // set preferred block shape arguments
setGemmBatchedTensorCoreFloat16Argments(&mInfo); setGemmBatchedTensorCoreFloat16Argments(&mInfo);
#else
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mBtdB_Buffer, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]), // batch_stride_A
{(ElementInput_F16 *)mResource->mFilter, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]), // batch_stride_B
{(ElementOutput_F16 *)mResource->mBias, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)mMatmul_Buffer, mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBlock2}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mResource->mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
} }
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
#endif #endif
if(!mIsTuned) {
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mBtdB_Buffer, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]), // batch_stride_A
{(ElementInput_F16 *)mResource->mFilter, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]), // batch_stride_B
{(ElementOutput_F16 *)mResource->mBias, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)mMatmul_Buffer, mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBlock2}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mResource->mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
}
} else { } else {
typename GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication typename GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
@ -477,12 +480,15 @@ ErrorCode ConvWinogradExecution::onExecute(const std::vector<Tensor*> &inputs, c
cutlass::Status status = mGemmBatchedF16F32LnSm75(); cutlass::Status status = mGemmBatchedF16F32LnSm75();
cutlass_check(status); cutlass_check(status);
} else { } else {
#ifdef ENABLE_CUDA_TUNE_PARAM #ifdef ENABLE_CUDA_TUNE_PARAM
runGemmBatchedTensorCoreFloat16Infer(&mInfo); if(mIsTuned) {
#else runGemmBatchedTensorCoreFloat16Infer(&mInfo);
cutlass::Status status = mGemmBatchedF16F16LnSm75(); }
cutlass_check(status); #endif
#endif if(!mIsTuned) {
cutlass::Status status = mGemmBatchedF16F16LnSm75();
cutlass_check(status);
}
} }
} }
} }

View File

@ -71,6 +71,7 @@ private:
int mPadY; int mPadY;
int mBlock2; int mBlock2;
int mGpuComputeCap; int mGpuComputeCap;
bool mIsTuned = false;
int mActivationType; int mActivationType;
bool mFp16Infer = false; bool mFp16Infer = false;
bool mFp32Infer = false; bool mFp32Infer = false;

View File

@ -481,215 +481,219 @@ void MatMulExecution::setArguments(const std::vector<Tensor *> &inputs, const st
if(mFp16Infer) { if(mFp16Infer) {
#ifdef ENABLE_CUDA_TUNE_PARAM #ifdef ENABLE_CUDA_TUNE_PARAM
/* if(mGpuComputeCap >= 80) {
// 0 -> Gemm, 1~N -> BatchGemm mIsTuned = true;
int32_t batchSize = 0; /*
// [0]->A, [1]->B, [2]->bias, [3]->output // 0 -> Gemm, 1~N -> BatchGemm
std::pair<void *, int32_t> ptrOffset[4]; int32_t batchSize = 0;
int32_t batchOffset[4]; // [0]->A, [1]->B, [2]->bias, [3]->output
// [0]->alpha, [1]->beta, [2]->splitK std::pair<void *, int32_t> ptrOffset[4];
int32_t coefs[3]; int32_t batchOffset[4];
// 0 -> RowColumn, 1 -> RowRow // [0]->alpha, [1]->beta, [2]->splitK
int32_t layout; int32_t coefs[3];
bool epilogueVectorize // 0 -> RowColumn, 1 -> RowRow
*/ int32_t layout;
mInfo.problemSize[0] = mGemmInfo.elh[0]; bool epilogueVectorize
mInfo.problemSize[1] = mGemmInfo.elh[2]; */
mInfo.problemSize[2] = mGemmInfo.elhPad[1]; mInfo.problemSize[0] = mGemmInfo.elh[0];
mInfo.problemSize[1] = mGemmInfo.elh[2];
mInfo.problemSize[2] = mGemmInfo.elhPad[1];
mInfo.coefs[0] = 1; mInfo.coefs[0] = 1;
mInfo.coefs[1] = 0; mInfo.coefs[1] = 0;
if (inputs.size() > 2) { if (inputs.size() > 2) {
mInfo.coefs[1] = 1; mInfo.coefs[1] = 1;
}
mInfo.epilogueVectorize = true;
mInfo.epilogueType = 0;// Linear
mInfo.precisionType = 2;// FP16_FP16
mInfo.backend = mBackend;
if(mUseRRLayout) {
mInfo.batchSize = mBatch;
mInfo.layout = 1;
mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[2]);
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elhPad[2]);
mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs;
mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]* mBs;
mInfo.batchOffset[2] = 0;
mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elhPad[2];
} else {
if(hAlignment) {
mInfo.epilogueVectorize = true;
} else {
mInfo.epilogueVectorize = false;
} }
mInfo.epilogueVectorize = true;
mInfo.epilogueType = 0;// Linear
mInfo.precisionType = 2;// FP16_FP16
mInfo.backend = mBackend;
if(hAlignment && mConvertGemmSplitK) { if(mUseRRLayout) {
mInfo.batchSize = 0;
mInfo.layout = 0;
mInfo.coefs[2] = 16;
mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elh[2]);
} else {
mInfo.batchSize = mBatch; mInfo.batchSize = mBatch;
mInfo.layout = 0; mInfo.layout = 1;
mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]); mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[1]); mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[2]);
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0); mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elh[2]); mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elhPad[2]);
mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs; mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs;
mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs; mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]* mBs;
mInfo.batchOffset[2] = 0; mInfo.batchOffset[2] = 0;
mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elh[2]; mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elhPad[2];
}
}
getGemmBatchedTensorCoreFloat16Param(&mInfo);
// set preferred block shape arguments
setGemmBatchedTensorCoreFloat16Argments(&mInfo);
#else
if(mUseRRLayout) {
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign8RRSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
if(hAlignment) {
if(mConvertGemmSplitK) {
int split_k_slices = 16;
typename GemmTensor_F16_F16_Linear_AlignTensor_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
size_t workspace_size = GemmTensor_F16_F16_Linear_AlignTensor_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
cutlass::Status status = mGemmF16F16LnAlign8Sm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmF16F16LnAlign8Sm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign8RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
}
} else {
if(mConvertGemmSplitK) {
int split_k_slices = 16;
typename GemmTensor_F16_F16_Linear_AlignCuda_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
size_t workspace_size = GemmTensor_F16_F16_Linear_AlignCuda_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
cutlass::Status status = mGemmF16F16LnAlign1Sm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmF16F16LnAlign1Sm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
typename GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign1RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
}
}
}
} else {
if(hAlignment) {
mInfo.epilogueVectorize = true;
} else {
mInfo.epilogueVectorize = false;
}
if(hAlignment && mConvertGemmSplitK) {
mInfo.batchSize = 0;
mInfo.layout = 0;
mInfo.coefs[2] = 16;
mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elh[2]);
} else {
mInfo.batchSize = mBatch;
mInfo.layout = 0;
mInfo.ptrOffset[0] = std::make_pair((void *)mTempMatA, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[1] = std::make_pair((void *)mTempMatB, mGemmInfo.elhPad[1]);
mInfo.ptrOffset[2] = std::make_pair((void *)mBiasPtr, 0);
mInfo.ptrOffset[3] = std::make_pair((void *)C->deviceId(), mGemmInfo.elh[2]);
mInfo.batchOffset[0] = mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs;
mInfo.batchOffset[1] = mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs;
mInfo.batchOffset[2] = 0;
mInfo.batchOffset[3] = mGemmInfo.elh[0] * mGemmInfo.elh[2];
}
}
getGemmBatchedTensorCoreFloat16Param(&mInfo);
// set preferd block shape argments
setGemmBatchedTensorCoreFloat16Argments(&mInfo);
}
#endif
if(!mIsTuned) {
if(mUseRRLayout) {
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elhPad[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign8RRSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
if(hAlignment) {
if(mConvertGemmSplitK) {
int split_k_slices = 16;
typename GemmTensor_F16_F16_Linear_AlignTensor_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
size_t workspace_size = GemmTensor_F16_F16_Linear_AlignTensor_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
cutlass::Status status = mGemmF16F16LnAlign8Sm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmF16F16LnAlign8Sm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign8RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
}
} else {
if(mConvertGemmSplitK) {
int split_k_slices = 16;
typename GemmTensor_F16_F16_Linear_AlignCuda_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
size_t workspace_size = GemmTensor_F16_F16_Linear_AlignCuda_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
cutlass::Status status = mGemmF16F16LnAlign1Sm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmF16F16LnAlign1Sm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
} else {
typename GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::Arguments arguments{problem_size, // <- problem size of matrix multiplication
{(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1]* mAs), // batch_stride_A
{(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]}, // Ptr + ldm
(int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2]* mBs), // batch_stride_B
{(ElementOutput_F16 *)mBiasPtr, 0}, // Ptr + ldm if ldm = 0, vector,
(int64_t)(0), // batch_stride_bias
{(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]}, // Ptr + ldm
(int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]), // batch_stride_C
{alpha, beta}, // <- tuple of alpha and beta
mBatch}; // batch_count
size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::get_workspace_size(arguments);
if(workspace_size != 0) {
workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
mWorkspace = (void *)workspaceTensor.get()->buffer().device;
}
// Check the problem size is supported or not
cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75.can_implement(arguments);
cutlass_check(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = mGemmBatchedF16F16LnAlign1RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
cutlass_check(status);
}
}
}
}
} else {
if(mUseRRLayout) {
if(mNeedConvertMatAB) {
@@ -1239,32 +1243,35 @@ ErrorCode MatMulExecution::onExecute(const std::vector<Tensor *> &inputs, const
}
} else {
#ifdef ENABLE_CUDA_TUNE_PARAM
runGemmBatchedTensorCoreFloat16Infer(&mInfo);
#else
if(mUseRRLayout) {
cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75();
cutlass_check(status);
} else {
if(hAlignment) {
if(mConvertGemmSplitK) {
cutlass::Status status = mGemmF16F16LnAlign8Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75();
cutlass_check(status);
}
} else {
if(mConvertGemmSplitK) {
cutlass::Status status = mGemmF16F16LnAlign1Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75();
cutlass_check(status);
}
}
}
#endif
if(mIsTuned) {
runGemmBatchedTensorCoreFloat16Infer(&mInfo);
}
#endif
if(!mIsTuned) {
if(mUseRRLayout) {
cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75();
cutlass_check(status);
} else {
if(hAlignment) {
if(mConvertGemmSplitK) {
cutlass::Status status = mGemmF16F16LnAlign8Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75();
cutlass_check(status);
}
} else {
if(mConvertGemmSplitK) {
cutlass::Status status = mGemmF16F16LnAlign1Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75();
cutlass_check(status);
}
}
}
}
}
// printf("normal:%d rrlayout:%d convertab:%d halign:%d\n", mFp16Fp32MixInfer, mUseRRLayout, mNeedConvertMatAB, hAlignment);
return NO_ERROR;
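
Every branch in the two hunks above follows the same four-step CUTLASS calling sequence: query the workspace size for the argument struct, check that the problem shape is supported, initialize the operator with the arguments plus workspace, then launch it with operator(). A condensed, hedged sketch of that sequence in host CUDA C++, with MNN's Tensor-based workspace allocation replaced by plain cudaMalloc and GemmOp standing in for any of the pre-instantiated kernel types used above:

    #include <cuda_runtime.h>
    #include <cutlass/cutlass.h>
    #include <cstdint>

    template <typename GemmOp>
    bool prepareAndRun(GemmOp& op, const typename GemmOp::Arguments& args) {
        // 1. Workspace query: some tile / split-K configurations need scratch memory.
        size_t workspaceSize = GemmOp::get_workspace_size(args);
        void* workspace = nullptr;
        if (workspaceSize != 0 && cudaMalloc(&workspace, workspaceSize) != cudaSuccess) {
            return false;
        }
        // 2. Reject unsupported problem sizes / alignments up front.
        if (op.can_implement(args) != cutlass::Status::kSuccess) {
            cudaFree(workspace);
            return false;
        }
        // 3. Bind arguments and workspace once; the operator can then be launched many times.
        if (op.initialize(args, reinterpret_cast<uint8_t*>(workspace)) != cutlass::Status::kSuccess) {
            cudaFree(workspace);
            return false;
        }
        // 4. Launch and wait before releasing the scratch buffer (the real code keeps it in a Tensor).
        bool ok = (op() == cutlass::Status::kSuccess);
        cudaDeviceSynchronize();
        cudaFree(workspace);
        return ok;
    }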

View File

@@ -84,6 +84,7 @@ private:
CutlassGemmInfo mGemmInfo;
int mBatch = 1;
int mGpuComputeCap;
bool mIsTuned = false;
bool mFp16Infer = false;
bool mFp32Infer = false;
bool mFp16Fp32MixInfer = false;

View File

@@ -1083,9 +1083,17 @@ void BinaryBlit(uint8_t* output, const uint8_t* input, const uint8_t* input1, co
BinaryBlitTemplateFloat((float*)output, (float*)input, (float*)input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType, activationType);
} else if (type.bits == 16) {
BinaryBlitTemplateFloat((half*)output, (half*)input, (half*)input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType, activationType);
} else {
MNN_ERROR("CUDA not supoort data code:%d, data bits:%d\n", type.code, type.bits);
}
} else if (type.code == halide_type_int) {
BinaryBlitTemplateInt32(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType, activationType);
if(type.bits == 32) {
BinaryBlitTemplateInt32(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType, activationType);
} else {
MNN_ERROR("CUDA not supoort data code:%d, data bits:%d\n", type.code, type.bits);
}
} else {
MNN_ERROR("CUDA not supoort data code:%d, data bits:%d\n", type.code, type.bits);
}
}
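
The hunk above tightens BinaryBlit's type dispatch: every (code, bits) combination without a matching template instantiation now reports an error instead of silently falling through. The shape of that dispatch, reduced to a stand-alone stub (halide_type_t is mirrored here only so the snippet compiles on its own; the real blit arguments are omitted):

    #include <cstdio>
    #include <cstdint>

    struct halide_type_t { uint8_t code; uint8_t bits; };   // mirrors HalideRuntime.h
    enum halide_type_code_t { halide_type_int = 0, halide_type_uint = 1, halide_type_float = 2 };

    void dispatchBlit(halide_type_t type) {
        if (type.code == halide_type_float) {
            if (type.bits == 32)      { /* BinaryBlitTemplateFloat with float */ }
            else if (type.bits == 16) { /* BinaryBlitTemplateFloat with half  */ }
            else { printf("CUDA not support data code:%d, data bits:%d\n", type.code, type.bits); }
        } else if (type.code == halide_type_int) {
            if (type.bits == 32) { /* BinaryBlitTemplateInt32 */ }
            else { printf("CUDA not support data code:%d, data bits:%d\n", type.code, type.bits); }
        } else {
            printf("CUDA not support data code:%d, data bits:%d\n", type.code, type.bits);
        }
    }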

View File

@@ -109,35 +109,38 @@ ErrorCode CutlassConvCommonExecution::runCutlassGemmFunc() {
return NO_ERROR;
}
#ifdef ENABLE_CUDA_TUNE_PARAM
runGemmTensorCoreFloat16Infer(&mInfo);
#else
if(mActivationType == 1) {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32ReluSm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16ReluSm75();
cutlass_check(status);
}
} else if(mActivationType == 2) {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32Relu6Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16Relu6Sm75();
cutlass_check(status);
}
} else {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32LnSm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16LnSm75();
cutlass_check(status);
}
}
#endif
if(mIsTuned) {
runGemmTensorCoreFloat16Infer(&mInfo);
}
#endif
if(!mIsTuned) {
if(mActivationType == 1) {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32ReluSm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16ReluSm75();
cutlass_check(status);
}
} else if(mActivationType == 2) {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32Relu6Sm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16Relu6Sm75();
cutlass_check(status);
}
} else {
if(mFp16Fp32MixInfer) {
cutlass::Status status = mGemmF16F32LnSm75();
cutlass_check(status);
} else {
cutlass::Status status = mGemmF16F16LnSm75();
cutlass_check(status);
}
}
}
return NO_ERROR;
}
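
With the mIsTuned flag, the convolution and MatMul execute paths now share one pattern: launch the auto-tuned tensor-core kernel when tuning produced a configuration at resize time, otherwise fall back to the statically instantiated CUTLASS operator. Stripped of the per-activation and per-precision branches, the control flow is roughly the following illustrative helper (the two launch functions are stand-ins for the member operators used above):

    #include <cutlass/cutlass.h>

    static cutlass::Status launchTunedKernel()   { return cutlass::Status::kSuccess; }  // e.g. runGemmTensorCoreFloat16Infer(&mInfo)
    static cutlass::Status launchCutlassKernel() { return cutlass::Status::kSuccess; }  // e.g. mGemmF16F16LnSm75()

    bool runGemmSketch(bool isTuned) {
    #ifdef ENABLE_CUDA_TUNE_PARAM
        if (isTuned) {
            return launchTunedKernel() == cutlass::Status::kSuccess;   // tuned block shape found at resize time
        }
    #endif
        // Otherwise use the pre-initialized CUTLASS operator selected by precision,
        // activation type and alignment.
        return launchCutlassKernel() == cutlass::Status::kSuccess;
    }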

View File

@@ -94,6 +94,7 @@ protected:
GemmTensor_BF16_BF16_Relu6_AlignTensor_Sm80 mGemmBF16BF16Relu6Sm80;
#endif
int mGpuComputeCap = 75;
bool mIsTuned = false;
int mActivationType = 0;
bool mFp16Infer = false;
bool mFp32Infer = false;

View File

@@ -871,6 +871,17 @@ const char* shader_MetalBackend_metal =
" uint4 extent;//dstStride[3]+dstOffset\n"
" uint4 imageSize;\n"
"};\n"
"struct MemsetInfo {\n"
" int4 V;\n"
" uint4 size;\n"
"};\n"
"kernel void fill_intx4(device int4 *out [[buffer(0)]],\n"
" constant MemsetInfo &info [[buffer(1)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<info.size.x) {\n"
" out[gid.x]=info.V;\n"
" }\n"
"}\n"
"kernel void blit_intx4(const device int4 *in [[buffer(0)]],\n"
" device int4 *out [[buffer(1)]],\n"
" constant SamplerInfo &info [[buffer(2)]],\n"
@@ -1776,34 +1787,6 @@ const char* shader_MetalROIPooling_metal =
" out[int(gid.z)*s.output_size+int(gid.y)*s.output_width+int(gid.x)]=max4;\n"
"}\n"
;
const char* shader_MetalCast_metal =
"using namespace metal;\n"
"kernel void cast_float_to_int32(const device M *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=int(in[int(gid)]);\n"
"}\n"
"kernel void cast_int32_to_float(const device int *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=M(in[int(gid)]);\n"
"}\n"
"kernel void cast_uint8_to_float(const device uchar *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=M(in[int(gid)]);\n"
"}\n"
"kernel void cast_uint8_to_int(const device uchar *in [[buffer(0)]],\n"
" device int *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=M(in[int(gid)]);\n"
"}\n"
"kernel void cast_float_to_uint8(const device M *in [[buffer(0)]],\n"
" device uchar *out [[buffer(1)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" out[int(gid)]=uchar(in[int(gid)]);\n"
"}\n"
;
const char* shader_MetalConvolution1x1_metal =
"#define CONV_UNROLL (4)\n"
"#define CONV_UNROLL_L (8)\n"
@@ -2200,68 +2183,6 @@ const char* shader_MetalConvolution1x1_metal =
"}\n"
;
const char* shader_MetalConvolutionGEMM_metal =
"struct conv_im2col_cst {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int batch;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation;\n"
"};\n"
"kernel void conv_im2col(const device M4 *im [[buffer(0)]],\n"
" device M4 *cols [[buffer(1)]],\n"
" constant conv_im2col_cst& cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto z=gid.z % cst.input_slice;\n"
" auto b=gid.z/cst.input_slice;\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height && (int)b<cst.batch) {\n"
" int offset_x=gid.x*cst.stride_x-cst.pad_x;\n"
" int offset_y=gid.y*cst.stride_y-cst.pad_y;\n"
" int index=b*cst.output_size+gid.y*cst.output_width+gid.x;\n"
" int cols_y=index/4;\n"
" int cols_x=index % 4+z*cst.kernel_size*4;\n"
" \n"
" auto xy_cols=cols+cols_y*cst.kernel_size*cst.input_slice*4+cols_x;\n"
" auto xy_im=im+b*cst.input_size*cst.input_slice+z*cst.input_size;\n"
" for (int ky=0,src_y=offset_y; ky<cst.kernel_y; ky++,src_y += cst.dilation_y) {\n"
" for (int kx=0,src_x=offset_x; kx<cst.kernel_x; kx++,src_x += cst.dilation_x) {\n"
" auto pad=src_x<0 || src_y<0 || src_x >= cst.input_width || src_y >= cst.input_height;\n"
" xy_cols[(ky*cst.kernel_x+kx)*4]=pad ? 0 : xy_im[src_y*cst.input_width+src_x];\n"
" }\n"
" }\n"
" }\n"
"}\n"
"kernel void conv_col2im(const device M4 *cols [[buffer(0)]],\n"
" device M4 *im [[buffer(1)]],\n"
" const device M4 *biasTerms [[buffer(2)]],\n"
" constant conv_im2col_cst& cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto z=gid.z % cst.output_slice;\n"
" auto b=gid.z/cst.output_slice;\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height && (int)b<cst.batch) {\n"
" int index=b*cst.output_size+gid.y*cst.output_width+gid.x;\n"
" auto src_x=index/4;\n"
" auto src_y=index % 4+z*4;\n"
" auto src_y_stride=UP_DIV(cst.output_size*cst.batch,4);\n"
" \n"
" auto v=cols[(int)src_y*src_y_stride+(int)src_x]+biasTerms[(int)z];\n"
" im[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(v,cst.activation);\n"
" }\n"
"}\n"
"struct matmul4x4_const {\n" "struct matmul4x4_const {\n"
" int output_width;\n" " int output_width;\n"
" int output_height;\n" " int output_height;\n"
@ -2428,8 +2349,6 @@ const char* shader_MetalDefine_metal =
"// \n" "// \n"
"#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n" "#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n"
"#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n" "#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n"
"// whether store with float32\n"
"#define MNN_METAL_FULL_PRECISION 0 // should edit in .h too\n"
"// whether computer with float32 when store with float16\n" "// whether computer with float32 when store with float16\n"
"#define MNN_METAL_FLOAT32_COMPUTER 1 //\n" "#define MNN_METAL_FLOAT32_COMPUTER 1 //\n"
"#if MNN_METAL_FULL_PRECISION\n" "#if MNN_METAL_FULL_PRECISION\n"

View File

@@ -16,7 +16,6 @@ extern const char* shader_MetalScale_metal;
extern const char* shader_MetalDeconvolution_metal;
extern const char* shader_MetalPooling_metal;
extern const char* shader_MetalROIPooling_metal;
extern const char* shader_MetalCast_metal;
extern const char* shader_MetalConvolution1x1_metal;
extern const char* shader_MetalConvolutionGEMM_metal;
extern const char* shader_MetalResize_metal;

View File

@@ -1,4 +1,8 @@
FILE(GLOB MNN_Metal_SRC ${CMAKE_CURRENT_LIST_DIR}/*.mm ${CMAKE_CURRENT_LIST_DIR}/*.hpp ${CMAKE_CURRENT_LIST_DIR}/*.h ${CMAKE_CURRENT_LIST_DIR}/*.cpp)
IF(MNN_SUPPORT_RENDER)
file(GLOB MNN_Metal_Render_SRC ${CMAKE_CURRENT_LIST_DIR}/render/*.mm ${CMAKE_CURRENT_LIST_DIR}/render/*.hpp ${CMAKE_CURRENT_LIST_DIR}/render/*.cpp)
list(APPEND MNN_Metal_SRC ${MNN_Metal_Render_SRC})
ENDIF()
FILE(GLOB MNN_Metal_KERNELS_SRC ${CMAKE_CURRENT_LIST_DIR}/*.metal)
option(MNN_METALLIB_SOURCE "Use Metal Source Directly" ON)
add_library(MNNMetal OBJECT ${MNN_Metal_SRC} "${CMAKE_CURRENT_LIST_DIR}/MetalOPRegister.mm")

View File

@@ -42,6 +42,7 @@ typedef struct {
@property (strong, nonatomic, readonly) id<MTLDevice> device;
/** max memory length cound be used in threadgroup */
@property (assign, nonatomic, readonly) BOOL isCommitEachShader;
@property (assign, nonatomic, readonly) BOOL isIphone;
/**
 * @brief alloc temp buffer on device
@ -60,19 +61,6 @@ typedef struct {
*/ */
- (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size bytes:(const void *)bytes access:(MNN::MetalAccess)access; - (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size bytes:(const void *)bytes access:(MNN::MetalAccess)access;
/**
* @brief create compute encoder on default command buffer
* @return created encoder
*/
- (id<MTLComputeCommandEncoder>)encoder;
- (id<MTLComputeCommandEncoder>)encoder_net;
/**
* @brief create fill encoder on default command buffer
* @return created encoder
*/
- (id<MTLBlitCommandEncoder>)encoderBlit;
- (id<MTLBlitCommandEncoder>)encoderBlit_net;
/** /**
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline. * @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
@ -80,7 +68,7 @@ typedef struct {
* @param encoder command encoder * @param encoder command encoder
* @return bandwidth info for function * @return bandwidth info for function
*/ */
- (MNN::MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder;
- (MNN::MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16;
/** /**
* @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline. * @brief load encoder with function name. returns maxTotalThreadsPerThreadgroup of pipeline.
@ -88,22 +76,15 @@ typedef struct {
* @param encoder command encoder * @param encoder command encoder
* @return bandwidth info for function * @return bandwidth info for function
*/ */
- (id<MTLCommandBuffer>) newCmdBuffer:(MTLSize) localIndex;
- (id<MTLCommandBuffer>) newCmdBuffer:(MTLSize) localIndex queue:(id<MTLCommandQueue>) cmdqueue;
- (NSUInteger)timeUsed:(id<MTLCommandBuffer>) buffer;
- (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MNN::MetalRuntime *) rt shaderName:(std::string) kernelName;
- (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MNN::MetalRuntime *) rt shaderName:(std::string) kernelName queue:(id<MTLCommandQueue>) cmdqueue;
- (NSUInteger)PipelinetimeUsed: (id<MTLComputePipelineState>)pipeline global:(MTLSize)globals local:(MTLSize)locals loop:(NSUInteger)count buffer:(NSArray *)buffers queue:(id<MTLCommandQueue>) cmdqueue;
- (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device; - (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device;
/**
* @brief commit commands
*/
- (void)commit;
- (void)commit_net;
/**
* @brief wait for completion
*/
- (void)wait;
/** /**
* @brief dispatch encoder with default settings * @brief dispatch encoder with default settings
@ -126,8 +107,8 @@ typedef struct {
threads:(MTLSize)threads threads:(MTLSize)threads
threadsPerGroup:(MTLSize)threadsPerGroup threadsPerGroup:(MTLSize)threadsPerGroup
bandwidth:(MNN::MetalBandwidth)bandwidth; bandwidth:(MNN::MetalBandwidth)bandwidth;
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name;
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name fp16:(BOOL)fp16;
- (id<MTLComputePipelineState>)pipelineWithSource:(NSString *)source name:(NSString *)name;
- (id<MTLComputePipelineState>)pipelineWithSourceOption:(NSString *)source name:(NSString *)name options:(MTLCompileOptions *)options;
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) pipeline threads:(MTLSize)threads; - (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) pipeline threads:(MTLSize)threads;
- (std::pair<MTLSize, MTLSize>)computeBestGroupAndLocal:(id<MTLComputePipelineState>) bw threads:(MTLSize)t; - (std::pair<MTLSize, MTLSize>)computeBestGroupAndLocal:(id<MTLComputePipelineState>) bw threads:(MTLSize)t;

View File

@ -22,18 +22,15 @@ using namespace MNN;
@interface MNNMetalContext ()
// public
@property (strong, nonatomic) id<MTLDevice> device;
@property (strong, nonatomic) id<MTLCommandQueue> commandQueue;
@property (strong, nonatomic) id<MTLCommandBuffer> commandBuffer;
@property (strong, nonatomic) id<MTLCommandBuffer> commandBuffer_net;
@property (assign, nonatomic) BOOL isIphone;
// private
@property (strong, nonatomic) NSMutableDictionary<NSString *, id<MTLComputePipelineState>> *caches;
@property (strong, nonatomic) NSMutableArray<id<MTLCommandBuffer>> *waitings;
@property (strong, nonatomic) NSMutableDictionary<NSString *, id<MTLLibrary>>* library;
@property (strong, nonatomic) NSMutableDictionary<NSString *, id<MTLComputePipelineState>> *cachesFp32;
@property (strong, nonatomic) NSMutableDictionary<NSString *, id<MTLComputePipelineState>> *cachesFp16;
@end
@implementation MNNMetalContext
static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *, id<MTLLibrary>>* libraryMap) {
static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *, id<MTLComputePipelineState>>* libraryMap, bool usefp16) {
AUTOTIME;
ShaderMap shader;
auto first = shader.search("shader_MetalDefine_metal");
@@ -47,6 +44,11 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
if (iter.first == "shader_MetalConvolutionActivation_metal") {
continue;
}
if (!usefp16) {
total << "#define MNN_METAL_FULL_PRECISION 1\n";
} else {
total << "#define MNN_METAL_FULL_PRECISION 0\n";
}
total << first << "\n" << second << "\n" << iter.second;
auto totalString = total.str();
auto totalNSString = [[NSString alloc] initWithUTF8String:totalString.c_str()];
@@ -64,7 +66,15 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
}
auto functionNames = [library functionNames];
for(int i=0; i<functionNames.count ; i++) {
libraryMap[functionNames[i]] = library;
id<MTLFunction> function = [library newFunctionWithName:functionNames[i]];
if (!function) {
MNN_ERROR("Create Function in metal error\n");
continue;
}
NSError *error = nil;
auto result = [device newComputePipelineStateWithFunction:function error:&error];
libraryMap[functionNames[i]] = result;
}
}
}
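
createLibrary now compiles every shader source twice, prefixing it with MNN_METAL_FULL_PRECISION 1 (fp32 storage) or 0 (fp16), and stores a ready MTLComputePipelineState per function name instead of the raw MTLLibrary. The lookup side then only has to pick one of two tables. A plain C++ sketch of that two-table cache idea, with PipelineHandle as a placeholder for id<MTLComputePipelineState>:

    #include <string>
    #include <unordered_map>

    using PipelineHandle = void*;  // placeholder for id<MTLComputePipelineState>

    struct PipelineCaches {
        std::unordered_map<std::string, PipelineHandle> fp16;  // compiled with MNN_METAL_FULL_PRECISION 0
        std::unordered_map<std::string, PipelineHandle> fp32;  // compiled with MNN_METAL_FULL_PRECISION 1

        // Mirrors -pipelineWithName:fp16: — a miss simply returns null, no lazy compilation here.
        PipelineHandle find(const std::string& name, bool useFp16) const {
            const auto& table = useFp16 ? fp16 : fp32;
            auto it = table.find(name);
            return it == table.end() ? nullptr : it->second;
        }
    };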
@ -96,19 +106,29 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return NO; return NO;
} }
+ (BOOL)isIphone{
struct utsname systemInfo;
uname(&systemInfo);
NSString *deviceString = [NSString stringWithCString:systemInfo.machine encoding:NSASCIIStringEncoding];
NSString *subString = @"iPhone";
NSRange range = [deviceString rangeOfString:subString];
if (range.location != NSNotFound) {
return YES;
}
return NO;
}
- (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device {
MNN_ASSERT(nullptr != context);
_device = context->device;
_library = [NSMutableDictionary dictionary];
createLibrary(_device, _library);
_commandQueue = context->queue;
_commandBuffer = [_commandQueue commandBuffer];
_commandBuffer_net = [_commandQueue commandBuffer];
_caches = [NSMutableDictionary dictionary];
_waitings = [NSMutableArray array];
_cachesFp16 = [NSMutableDictionary dictionary];
_cachesFp32 = [NSMutableDictionary dictionary];
_isCommitEachShader = self.class.commit_frequent;
_isIphone = self.class.isIphone;
return (0 != [_library count]);
createLibrary(_device, _cachesFp16, true);
createLibrary(_device, _cachesFp32, false);
return nil != _device;
}
- (instancetype)init { - (instancetype)init {
@ -139,42 +159,16 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return [_device newBufferWithBytes:bytes length:size options:[self optionForAccess:access]]; return [_device newBufferWithBytes:bytes length:size options:[self optionForAccess:access]];
} }
#pragma mark enqueue - (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name fp16:(BOOL)fp16 {
- (id<MTLFunction>)functionWithName:(NSString *)name { if (fp16) {
if (!name) return _cachesFp16[name];
return nil; }
auto lib = _library[name]; return _cachesFp32[name];
id<MTLFunction> result = [lib newFunctionWithName:name];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
if (@available(iOS 10.0, *))
result.label = name;
#endif
return result;
} }
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name { - (id<MTLComputePipelineState>)pipelineWithSourceOption:(NSString *)source name:(NSString *)name options:(MTLCompileOptions *)options {
id<MTLComputePipelineState> result = _caches[name];
if (result)
return result;
id<MTLFunction> function = [self functionWithName:name];
if (!function)
return nil;
NSError *error = nil;
result = [_device newComputePipelineStateWithFunction:function error:&error];
#if MNN_METAL_DEBUG
if (error)
printf("[METAL] create pipeline error: %s\n", error.localizedDescription.UTF8String);
#endif
if (result)
_caches[name] = result;
return result;
}
- (id<MTLComputePipelineState>)pipelineWithSource:(NSString *)source name:(NSString *)name {
NSError *err = nil; NSError *err = nil;
auto library = [_device newLibraryWithSource:source options:nil error:&err]; auto library = [_device newLibraryWithSource:source options:options error:&err];
if (nil == library) { if (nil == library) {
if (err) { if (err) {
NSLog(@"Warning: pipelineWithSource error: %@", err); NSLog(@"Warning: pipelineWithSource error: %@", err);
@ -184,43 +178,11 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
id<MTLFunction> function = [library newFunctionWithName:name]; id<MTLFunction> function = [library newFunctionWithName:name];
NSError *error = nil; NSError *error = nil;
id<MTLComputePipelineState> result = [_device newComputePipelineStateWithFunction:function error:&error]; id<MTLComputePipelineState> result = [_device newComputePipelineStateWithFunction:function error:&error];
if (result)
_caches[name] = result;
return result; return result;
} }
- (id<MTLComputeCommandEncoder>)encoder { - (MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder fp16:(BOOL)fp16 {
id<MTLComputeCommandEncoder> result = [_commandBuffer computeCommandEncoder]; id<MTLComputePipelineState> pipeline = [self pipelineWithName:name fp16:fp16];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
result.label = nil;
#endif
return result;
}
- (id<MTLBlitCommandEncoder>)encoderBlit {
id<MTLBlitCommandEncoder> result = [_commandBuffer blitCommandEncoder];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
result.label = nil;
#endif
return result;
}
- (id<MTLComputeCommandEncoder>)encoder_net {
id<MTLComputeCommandEncoder> result = [_commandBuffer_net computeCommandEncoder];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
result.label = nil;
#endif
return result;
}
- (id<MTLBlitCommandEncoder>)encoderBlit_net {
id<MTLBlitCommandEncoder> result = [_commandBuffer_net blitCommandEncoder];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
result.label = nil;
#endif
return result;
}
- (MetalBandwidth)load:(NSString *)name encoder:(id<MTLComputeCommandEncoder>)encoder {
id<MTLComputePipelineState> pipeline = [self pipelineWithName:name];
MNN_ASSERT(nil != pipeline); MNN_ASSERT(nil != pipeline);
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK #if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
@ -238,13 +200,6 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return {pipeline.threadExecutionWidth, pipeline.maxTotalThreadsPerThreadgroup, NO}; return {pipeline.threadExecutionWidth, pipeline.maxTotalThreadsPerThreadgroup, NO};
} }
- (id<MTLCommandBuffer>) newCmdBuffer:(MTLSize) localIndex {
id<MTLCommandBuffer> cmdBuffer = [_commandQueue commandBuffer]; // create a new command buffer
std::string label = std::to_string((int)localIndex.width) + "_" + std::to_string((int)localIndex.height) + "_" + std::to_string((int)localIndex.depth);
cmdBuffer.label = [NSString stringWithCString:label.c_str() encoding:[NSString defaultCStringEncoding]];
return cmdBuffer;
}
- (NSUInteger)timeUsed:(id<MTLCommandBuffer>)buffer { - (NSUInteger)timeUsed:(id<MTLCommandBuffer>)buffer {
// Get ns precision time // Get ns precision time
auto start = mach_absolute_time(); auto start = mach_absolute_time();
@ -256,8 +211,14 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return (end-start)/1000; return (end-start)/1000;
} }
- (id<MTLCommandBuffer>) newCmdBuffer:(MTLSize) localIndex queue:(id<MTLCommandQueue>) cmdqueue {
id<MTLCommandBuffer> cmdBuffer = [cmdqueue commandBuffer]; // create a new command buffer
std::string label = std::to_string((int)localIndex.width) + "_" + std::to_string((int)localIndex.height) + "_" + std::to_string((int)localIndex.depth);
cmdBuffer.label = [NSString stringWithCString:label.c_str() encoding:[NSString defaultCStringEncoding]];
return cmdBuffer;
}
- (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MetalRuntime *) rt shaderName:(std::string) kernelName { - (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MetalRuntime *) rt shaderName:(std::string) kernelName queue:(id<MTLCommandQueue>) cmdqueue {
NSUInteger gid_x = threads.width; NSUInteger gid_x = threads.width;
NSUInteger gid_y = threads.height; NSUInteger gid_y = threads.height;
NSUInteger gid_z = threads.depth; NSUInteger gid_z = threads.depth;
@ -289,7 +250,7 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
{ {
//get original trick time //get original trick time
{ {
id<MTLCommandBuffer> commamd_buffer = [self newCmdBuffer:thread.second]; id<MTLCommandBuffer> commamd_buffer = [self newCmdBuffer:thread.second queue:cmdqueue];
id<MTLComputeCommandEncoder> encoder = [commamd_buffer computeCommandEncoder]; id<MTLComputeCommandEncoder> encoder = [commamd_buffer computeCommandEncoder];
int loop = count; int loop = count;
@ -344,7 +305,7 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
} }
MTLSize local = {x, y, z}; MTLSize local = {x, y, z};
MTLSize global = {UP_DIV(gid_x, x), UP_DIV(gid_y, y), UP_DIV(gid_z, z)}; MTLSize global = {UP_DIV(gid_x, x), UP_DIV(gid_y, y), UP_DIV(gid_z, z)};
id<MTLCommandBuffer> commamd_buffer = [self newCmdBuffer:local]; id<MTLCommandBuffer> commamd_buffer = [self newCmdBuffer:local queue:cmdqueue];
id<MTLComputeCommandEncoder> encoder = [commamd_buffer computeCommandEncoder]; id<MTLComputeCommandEncoder> encoder = [commamd_buffer computeCommandEncoder];
int loop = count; int loop = count;
@ -388,50 +349,27 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return std::make_tuple(thread.first, thread.second, min_time); return std::make_tuple(thread.first, thread.second, min_time);
} }
#pragma mark dispatch
- (void)commit {
if (_commandBuffer.status < MTLCommandBufferStatusCommitted) {
[_commandBuffer commit];
[_waitings addObject:_commandBuffer];
_commandBuffer = [_commandQueue commandBuffer]; // create a new command buffer
}
}
- (void)commit_net { - (NSUInteger)PipelinetimeUsed: (id<MTLComputePipelineState>)pipeline global:(MTLSize)globals local:(MTLSize)locals loop:(NSUInteger)count buffer:(NSArray *)buffers queue:(id<MTLCommandQueue>) cmdqueue{
if (_commandBuffer_net.status < MTLCommandBufferStatusCommitted) { NSUInteger time = 0;
[_commandBuffer_net commit]; MTLSize local_size = {locals.width, locals.height, locals.depth};
[_waitings addObject:_commandBuffer_net]; MTLSize global_size = {globals.width, globals.height, globals.depth};
_commandBuffer_net = [_commandQueue commandBuffer]; // create a new command buffer id<MTLCommandBuffer> commamd_buffer = [self newCmdBuffer:local_size queue:cmdqueue];
} id<MTLComputeCommandEncoder> encoder = [commamd_buffer computeCommandEncoder];
}
- (void)wait { int loop = count;
for (id<MTLCommandBuffer> buffer in _waitings) { while(loop--) {
if (buffer.status >= MTLCommandBufferStatusCompleted) [encoder setComputePipelineState:pipeline];
continue; for(NSUInteger idx = 0; idx < buffers.count; idx++) {
[encoder setBuffer:[buffers objectAtIndex:idx] offset:0 atIndex:idx];
#if MNN_METAL_BENCHMARK
NSTimeInterval begin = [NSDate timeIntervalSinceReferenceDate];
[buffer waitUntilCompleted];
NSTimeInterval end = [NSDate timeIntervalSinceReferenceDate];
if (@available(iOS 10.3, *)) {
printf("[METAL] commit costs: %.3fms\t(kernel: %.3fms, GPU: %.3fms)\n", (end - begin) * 1000.f,
(buffer.kernelEndTime - buffer.kernelStartTime) * 1000.f,
(buffer.GPUEndTime - buffer.GPUStartTime) * 1000.f);
} else {
printf("[METAL] commit costs: %.3fms\n", (end - begin) * 1000.f);
} }
#else
[buffer waitUntilCompleted];
#endif
#if MNN_METAL_DEBUG [encoder dispatchThreadgroups:global_size threadsPerThreadgroup:local_size];
if (buffer.error) {
printf("[METAL] %s\n", buffer.error.localizedDescription.UTF8String);
}
#endif
} }
[_waitings removeAllObjects]; [encoder endEncoding];
time = [self timeUsed :commamd_buffer];
return time;
} }
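
timeUsed: and PipelinetimeUsed: bracket a blocking wait on the command buffer with mach_absolute_time() samples and report the scaled difference in microseconds. Outside Objective-C the same measurement looks roughly like this (the timebase conversion is the standard mach idiom, added here for portability; runAndWait stands for committing the buffer and waiting on it):

    #include <mach/mach_time.h>
    #include <cstdint>

    template <typename Fn>
    uint64_t measureMicroseconds(Fn&& runAndWait) {
        mach_timebase_info_data_t info;
        mach_timebase_info(&info);                 // ticks -> nanoseconds ratio
        uint64_t start = mach_absolute_time();
        runAndWait();                              // e.g. commit the command buffer and wait until completed
        uint64_t end = mach_absolute_time();
        uint64_t ns = (end - start) * info.numer / info.denom;
        return ns / 1000;                          // the context reports microseconds
    }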
static NSUInteger smallest_log2(NSUInteger integer) { static NSUInteger smallest_log2(NSUInteger integer) {
@@ -663,7 +601,7 @@ void printBuffer(const void *content, unsigned long bytes, const char *fmt) {
}
} else if (type == halide_type_float) {
if (bits == 16) { // half
printBuffer<metal_float>(bytes, length, "%.4f");
printBuffer<__fp16>(bytes, length, "%.4f");
} else { // float
printBuffer<float>(bytes, length, "%.4f");
}

View File

@@ -37,6 +37,10 @@ public:
}
void setGpuMode(const int cl_mode_num);
void setCommandQueue(id<MTLCommandQueue> queue);
id<MTLCommandQueue> getCommandQueue() const {
return mQueue;
}
std::pair<const void*, size_t> makeCache(TunedInfo* info);
bool setCache(std::pair<const void*, size_t> cache);
@@ -70,10 +74,12 @@ private:
std::map<std::pair<std::string, std::vector<uint32_t>>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>> mTunedThreadGroup;
private:
id<MTLCommandQueue> mQueue = nil;
std::vector<uint8_t> mBuffer;
const void* mCacheOutside = nullptr;
size_t mCacheOutsideSize = 0;
TunedInfo* mTunedInfo;
BackendConfig mDefaultConfig;
};
@@ -124,11 +130,13 @@ public:
 * @param creator registering creator.
 */
static void addCreator(OpType type, Creator *creator);
size_t getTensorSizeInBytes(const Tensor* tensor) const;
id<MTLBuffer> getHostBuffer(size_t size) const;
id<MTLBuffer> getConstBuffer(size_t size) const;
id<MTLComputePipelineState> makeComputePipelineWithSourceOption(const char* csource, const char* cname, MTLCompileOptions *options) const;
public:
MetalBackend(std::shared_ptr<EagerBufferAllocator> staticMem, const MetalRuntime* runtime);
MetalBackend(std::shared_ptr<EagerBufferAllocator> staticMem, const MetalRuntime* runtime, bool usefp16AsFp32);
virtual ~MetalBackend();
const MetalRuntime* runtime() const {
return mRuntime;
@@ -146,6 +154,7 @@ public:
virtual void onExecuteBegin() const override;
virtual void onExecuteEnd() const override;
virtual int onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) override;
virtual bool onGetTensorInfo(const Tensor* tensor, void* dstInfo) override;
public:
/**
@@ -164,7 +173,7 @@ public:
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const;
void flushEncoder() const;
id<MTLComputeCommandEncoder> encoder() const;
id<MTLComputeCommandEncoder> encoder_for_net() const;
void addOpEncoder(std::function<void(void)> opEncoder);
bool isCommandEncoderSet();
@@ -178,15 +187,36 @@ public:
} }
bool isCmdBufferCommit(); bool isCmdBufferCommit();
bool isIphone(){
return mIsIphone;
}
void commit() const;
void commit_net() const;
void wait() const;
id<MTLCommandQueue> queue() const {
return _commandQueue;
}
bool useFp16InsteadFp32() const {
return mUseFloatAsFp16;
}
private: private:
id<MTLCommandBuffer> getCommandBufferForBufferCopy() const;
id<MTLCommandBuffer> getCommandBufferForNet() const;
id<MTLComputeCommandEncoder> encoder_net() const;
mutable id<MTLCommandBuffer> _commandBuffer = nil;
mutable id<MTLCommandBuffer> _commandBuffer_net = nil;
mutable id<MTLCommandBuffer> _waiting = nil;
id<MTLCommandQueue> _commandQueue;
const MetalRuntime* mRuntime;
std::vector<id<MTLBuffer>> mHoldBuffers;
id<MTLBuffer> mShapeH2D;
id<MTLBuffer> mShapeD2H;
mutable NSUInteger mEncoderCount = 0;
mutable bool mOpEncoderSet = false;//whether has set encoder
mutable bool mOpFullSupport = true;
mutable bool mSupportDeferEncode = true;
mutable bool mFrameEncodeCache = false;
std::vector<std::function<void(void)>> mOpEncoders;
@ -199,6 +229,8 @@ private:
void onCopyHostToDevice(const Tensor *src, const Tensor *dst) const; void onCopyHostToDevice(const Tensor *src, const Tensor *dst) const;
void onCopyDeviceToHost(const Tensor *src, const Tensor *dst) const; void onCopyDeviceToHost(const Tensor *src, const Tensor *dst) const;
void onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const; void onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const;
bool mUseFloatAsFp16;
bool mIsIphone = false;
}; };

View File

@@ -34,6 +34,9 @@ struct TunedInfo {
};
void registerMetalOps();
#ifdef MNN_SUPPORT_RENDER
extern void registerMetalRenderOps();
#endif
static inline std::map<OpType, MetalBackend::Creator *> *getCreatorMap() {
static std::once_flag of;
@@ -50,17 +53,40 @@ void MetalBackend::addCreator(OpType t, Creator *c) {
map->insert(std::make_pair(t, c));
}
MetalBackend::MetalBackend(std::shared_ptr<EagerBufferAllocator> staticMem, const MetalRuntime* runtime) : Backend(MNN_FORWARD_METAL) {
MetalBackend::MetalBackend(std::shared_ptr<EagerBufferAllocator> staticMem, const MetalRuntime* runtime, bool usefp16AsFp32) : Backend(MNN_FORWARD_METAL) {
mRuntime = runtime;
mBufferPool.reset(new EagerBufferAllocator(EagerBufferAllocator::Allocator::createRecurse(staticMem.get()), 1024));
mStaticBufferPool = staticMem;
mShapeH2D = getConstBuffer(4 * sizeof(int));
mShapeD2H = getConstBuffer(4 * sizeof(int));
mOpFullSupport = true;
mUseFloatAsFp16 = usefp16AsFp32;
auto ctx = (__bridge MNNMetalContext *)context();
mIsIphone = ctx.isIphone;
if (runtime->getCommandQueue() == nil) {
// one command queue can create only a few command buffer, so let each backend own a command queue
_commandQueue = [[ctx device] newCommandQueue];
mSupportDeferEncode = true;
} else {
// otherwise forbid defer encode optimize
_commandQueue = runtime->getCommandQueue();
mSupportDeferEncode = false;
}
_commandBuffer = nil;
_commandBuffer_net = nil;
_waiting = nil;
}
MetalBackend::~MetalBackend() {
// Do nothing
flushEncoder();
}
id<MTLComputeCommandEncoder> result = [getCommandBufferForNet() computeCommandEncoder];
#if MNN_METAL_DEBUG || MNN_METAL_BENCHMARK
result.label = nil;
#endif
return result;
}
void *MetalBackend::context() const { void *MetalBackend::context() const {
return mRuntime->context(); return mRuntime->context();
} }
@ -81,8 +107,7 @@ private:
MemChunk mBuffer; MemChunk mBuffer;
EagerBufferAllocator* mAllocator; EagerBufferAllocator* mAllocator;
}; };
Backend::MemObj* MetalBackend::onAcquire(const Tensor *_tensor, StorageType storageType) {
auto tensor = const_cast<Tensor *>(_tensor);
size_t MetalBackend::getTensorSizeInBytes(const Tensor* tensor) const {
auto format = TensorUtils::getDescribe(tensor)->dimensionFormat;
size_t size;
if (MNN_DATA_FORMAT_NC4HW4 == format && tensor->dimensions() >= 2) {
@@ -107,16 +132,25 @@ Backend::MemObj* MetalBackend::onAcquire(const Tensor *_tensor, StorageType stor
size = ROUND_UP(size, 4);
}
if (0 == size) {
return nullptr;
return 0;
}
// use metal_float when meets float
if (halide_type_float == tensor->buffer().type.code && tensor->buffer().type.bits == 32) {
size*= sizeof(metal_float);
if (halide_type_float == tensor->buffer().type.code && tensor->buffer().type.bits == 32 && mUseFloatAsFp16) {
size *= 2;
} else {
size *= tensor->getType().bytes();
}
size_t align = 4 * sizeof(int);
size = ROUND_UP(size, align);
return size;
}
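
getTensorSizeInBytes centralizes the allocation-size rule: NC4HW4 tensors round the channel dimension up to a multiple of 4, fp32 data is stored as two bytes per element when the backend runs with useFp16InsteadFp32, and the final byte count is aligned to 16 bytes. A simplified stand-alone version of that arithmetic (it deliberately ignores the dimension-format handling done through TensorUtils, so treat it as a sketch only):

    #include <cstddef>

    inline size_t roundUp(size_t x, size_t y) { return (x + y - 1) / y * y; }

    size_t metalTensorBytes(size_t batch, size_t channel, size_t plane,
                            size_t elementBytes, bool isFloat32, bool storeAsFp16) {
        size_t elements = batch * roundUp(channel, 4) * plane;        // NC4HW4 packing
        size_t bytes    = elements * ((isFloat32 && storeAsFp16) ? 2 : elementBytes);
        return roundUp(bytes, 4 * sizeof(int));                       // 16-byte alignment
    }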
Backend::MemObj* MetalBackend::onAcquire(const Tensor *_tensor, StorageType storageType) {
auto tensor = const_cast<Tensor *>(_tensor);
size_t size = getTensorSizeInBytes(_tensor);
if (0 == size) {
return nullptr;
}
// reuse if possible
MemChunk buffer;
EagerBufferAllocator* allocator = nullptr;
@@ -159,7 +193,7 @@ Execution *MetalBackend::onCreate(const std::vector<Tensor *> &inputs, const std
auto map = getCreatorMap();
auto iter = map->find(op->type());
if (iter == map->end()) {
mOpFullSupport = false;
mSupportDeferEncode = false;
if (nullptr != op->name()) {
MNN_PRINT("Don't support type [%s], %s\n", EnumNameOpType(op->type()), op->name()->c_str());
} else {
@@ -170,7 +204,7 @@ Execution *MetalBackend::onCreate(const std::vector<Tensor *> &inputs, const std
auto exe = iter->second->onCreate(inputs, op, this, outputs);
if (NULL == exe) {
mOpFullSupport = false;
mSupportDeferEncode = false;
MNN_PRINT("The Creator Don't support type [%s], %s\n", MNN::EnumNameOpType(op->type()), op->name() ? op->name()->c_str() : "");
return NULL;
}
@ -192,8 +226,7 @@ void MetalBackend::onExecuteBegin() const {
} }
void MetalBackend::onExecuteEnd() const {
flushEncoder();
auto ctx = (__bridge MNNMetalContext *)context();
[ctx commit_net];
commit_net();
if(mFrameEncodeCache) {
for(auto opEncoder : mOpEncoders) {
@@ -202,6 +235,20 @@ void MetalBackend::onExecuteEnd() const {
setOpEncoder();
}
}
bool MetalBackend::onGetTensorInfo(const Tensor* tensor, void* dstInfo) {
if (nullptr == dstInfo) {
return true;
}
auto dst = (MNNMetalTensorContent*)dstInfo;
dst->type.code = halide_type_float;
if (mUseFloatAsFp16) {
dst->type.bits = 16;
} else {
dst->type.bits = 32;
}
MNNMetalGetTensorContent(dst, (void*)tensor);
return true;
}
bool MetalBackend::isCommandEncoderSet() { bool MetalBackend::isCommandEncoderSet() {
return mOpEncoderSet;// !isCommitEachShader & mOpFullSupport return mOpEncoderSet;// !isCommitEachShader & mOpFullSupport
@@ -350,14 +397,13 @@ void MetalBackend::onResizeBegin() {
// Finish last inference task if needed
flushEncoder();
auto ctx = (__bridge MNNMetalContext *)context();
[ctx commit_net];
[ctx wait];
commit_net();
wait();
}
ErrorCode MetalBackend::onResizeEnd() {
auto ctx = (__bridge MNNMetalContext *)context();
mFrameEncodeCache = (!ctx.isCommitEachShader && mOpFullSupport);
mFrameEncodeCache = (!ctx.isCommitEachShader && mSupportDeferEncode);
return NO_ERROR;
}
@ -368,13 +414,17 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *) (dst->deviceId()))->getBuffer(); auto device = (id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *) (dst->deviceId()))->getBuffer();
auto floats = src->getType().code == halide_type_float; auto floats = src->getType().code == halide_type_float;
// For command queue from user, need user to make sure last frame's gpu work is ready
bool needWait = mRuntime->getCommandQueue() == nil;
// cast // cast
if (sfmt == dfmt || src->dimensions() <= 1) { if (sfmt == dfmt || src->dimensions() <= 1) {
if (floats) { if (floats && mUseFloatAsFp16) {
NSUInteger size = src->elementSize(); NSUInteger size = src->elementSize();
auto sizeC4 = UP_DIV(size, 4); auto sizeC4 = UP_DIV(size, 4);
auto host = this->getHostBuffer(sizeC4 * 4 * sizeof(float)); auto host = this->getHostBuffer(sizeC4 * 4 * sizeof(float));
[ctx wait];// make sure previous gpu task finished. for reuse mHostBuffer and mShapeH2D if (needWait) {
wait();
}
memcpy(host.contents, src->host<float>(), src->size()); memcpy(host.contents, src->host<float>(), src->size());
unsigned int limits[] = { unsigned int limits[] = {
(unsigned int)sizeC4, (unsigned int)sizeC4,
@ -383,8 +433,8 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
1 1
}; };
::memcpy(mShapeH2D.contents, limits, sizeof(limits)); ::memcpy(mShapeH2D.contents, limits, sizeof(limits));
auto encoder = [ctx encoder]; auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto bandwidth = [ctx load: @"downcast_float4" encoder:encoder]; auto bandwidth = [ctx load: @"downcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:host offset:0 atIndex:0]; [encoder setBuffer:host offset:0 atIndex:0];
[encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1]; [encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
@ -397,12 +447,14 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
threads.first.width = UP_DIV(threads.first.width, threads.second.width); threads.first.width = UP_DIV(threads.first.width, threads.second.width);
[encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second]; [encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];
[encoder endEncoding]; [encoder endEncoding];
[ctx commit]; commit();
//[ctx wait]; //[ctx wait];
} else { } else {
[ctx wait]; if (needWait) {
memcpy(device.contents, src->host<uint8_t>(), src->size()); wait();
[ctx commit]; }
memcpy((uint8_t*)device.contents + TensorUtils::getDescribe(dst)->extra.offset, src->host<uint8_t>(), src->size());
commit();
//[ctx wait]; //[ctx wait];
} }
} }
@ -410,21 +462,23 @@ void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) cons
else { else {
auto buffer = getHostBuffer(src->elementSize() * sizeof(float)); auto buffer = getHostBuffer(src->elementSize() * sizeof(float));
[ctx wait];// make sure previous gpu task finished. for reuse mHostBuffer and mShapeH2D if (needWait) {
wait();
}
auto size = getTensorShape(mShapeH2D, src); auto size = getTensorShape(mShapeH2D, src);
memcpy(buffer.contents, src->host<float>(), src->size()); memcpy(buffer.contents, src->host<float>(), src->size());
auto encoder = [ctx encoder]; auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Down); auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Down);
MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt
auto bandwidth = [ctx load:kernel encoder:encoder]; auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:buffer offset:0 atIndex:0]; [encoder setBuffer:buffer offset:0 atIndex:0];
[encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1]; [encoder setBuffer:device offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[encoder setBuffer:mShapeH2D offset:0 atIndex:2]; [encoder setBuffer:mShapeH2D offset:0 atIndex:2];
[ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth]; [ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
[encoder endEncoding]; [encoder endEncoding];
[ctx commit]; commit();
//[ctx wait]; //[ctx wait];
} }
} }
@ -437,14 +491,14 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
auto floats = src->getType().code == halide_type_float; auto floats = src->getType().code == halide_type_float;
// cast // cast
if (sfmt == dfmt || src->dimensions() <= 1) { if (sfmt == dfmt || src->dimensions() <= 1) {
if (floats) { if (floats && mUseFloatAsFp16) {
auto eleSize = dst->elementSize(); auto eleSize = dst->elementSize();
eleSize = UP_DIV(eleSize, 4) * 4; eleSize = UP_DIV(eleSize, 4) * 4;
auto buffer = getHostBuffer(eleSize * dst->getType().bytes()); auto buffer = getHostBuffer(eleSize * dst->getType().bytes());
NSUInteger size = src->elementSize(); NSUInteger size = src->elementSize();
auto encoder = [ctx encoder]; auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto bandwidth = [ctx load: @"upcast_float4" encoder:encoder]; auto bandwidth = [ctx load: @"upcast_float4" encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0]; [encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:buffer offset:0 atIndex:1]; [encoder setBuffer:buffer offset:0 atIndex:1];
auto sizeC4 = UP_DIV(size, 4); auto sizeC4 = UP_DIV(size, 4);
@ -465,32 +519,32 @@ void MetalBackend::onCopyDeviceToHost(const Tensor *src, const Tensor *dst) cons
[encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second]; [encoder dispatchThreadgroups:threads.first threadsPerThreadgroup:threads.second];
[encoder endEncoding]; [encoder endEncoding];
[ctx commit]; commit();
[ctx wait]; wait();
memcpy(dst->host<float>(), buffer.contents, dst->size()); memcpy(dst->host<float>(), buffer.contents, dst->size());
} else { } else {
[ctx commit]; commit();
[ctx wait]; wait();
memcpy(dst->host<uint8_t>(), device.contents, dst->size()); memcpy(dst->host<uint8_t>(), (uint8_t*)device.contents + TensorUtils::getDescribe(src)->extra.offset, dst->size());
} }
} }
// convert // convert
else { else {
auto size = getTensorShape(mShapeD2H, src); auto size = getTensorShape(mShapeD2H, src);
auto buffer = getHostBuffer(dst->size()); auto buffer = getHostBuffer(dst->size());
auto encoder = [ctx encoder]; auto encoder = [getCommandBufferForBufferCopy() computeCommandEncoder];
auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Up); auto kernel = kernelForConvert(src->getType(), sfmt, dfmt, Up);
MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt MNN_ASSERT(kernel != nil); // unsupported sfmt to dfmt
auto bandwidth = [ctx load:kernel encoder:encoder]; auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0]; [encoder setBuffer:device offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:buffer offset:0 atIndex:1]; [encoder setBuffer:buffer offset:0 atIndex:1];
[encoder setBuffer:mShapeD2H offset:0 atIndex:2]; [encoder setBuffer:mShapeD2H offset:0 atIndex:2];
[ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth]; [ctx dispatchEncoder:encoder threads:size bandwidth:bandwidth];
[encoder endEncoding]; [encoder endEncoding];
[ctx commit]; commit();
[ctx wait]; wait();
memcpy(dst->host<float>(), buffer.contents, dst->size()); memcpy(dst->host<float>(), buffer.contents, dst->size());
} }
} }
@ -499,7 +553,7 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const { id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const {
auto ctx = (__bridge MNNMetalContext *)context(); auto ctx = (__bridge MNNMetalContext *)context();
auto standalone = encoder == nil; auto standalone = encoder == nil;
encoder = encoder ?: [ctx encoder]; encoder = encoder ?: [getCommandBufferForBufferCopy() computeCommandEncoder];
auto sfmt = TensorUtils::getDescribe(src)->dimensionFormat; auto sfmt = TensorUtils::getDescribe(src)->dimensionFormat;
auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat; auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat;
@ -507,7 +561,7 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
if (sfmt == dfmt || src->dimensions() <= 1) { if (sfmt == dfmt || src->dimensions() <= 1) {
auto flt = dst->getType().code == halide_type_float; auto flt = dst->getType().code == halide_type_float;
auto size = flt ? dst->elementSize() : dst->size(); auto size = flt ? dst->elementSize() : dst->size();
auto bandwidth = [ctx load:flt ? @"copy_float" : @"copy_byte" encoder:encoder]; auto bandwidth = [ctx load:flt ? @"copy_float" : @"copy_byte" encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0]; [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)src->deviceId())->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1]; [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[ctx dispatchEncoder:encoder threads:{(NSUInteger)size, 1, 1} bandwidth:bandwidth]; [ctx dispatchEncoder:encoder threads:{(NSUInteger)size, 1, 1} bandwidth:bandwidth];
@ -521,7 +575,7 @@ void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst,
} }
auto size = getTensorShape(shape, src); auto size = getTensorShape(shape, src);
auto bandwidth = [ctx load:kernel encoder:encoder]; auto bandwidth = [ctx load:kernel encoder:encoder fp16:mUseFloatAsFp16];
[encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(src->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0]; [encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(src->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(src)->extra.offset atIndex:0];
[encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(dst->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1]; [encoder setBuffer:( id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)(dst->buffer().device))->getBuffer() offset:TensorUtils::getDescribe(dst)->extra.offset atIndex:1];
[encoder setBuffer:shape offset:0 atIndex:2]; [encoder setBuffer:shape offset:0 atIndex:2];
@ -538,16 +592,15 @@ void MetalBackend::onCopyBuffer(const Tensor *src, const Tensor *dst) const {
flushEncoder(); flushEncoder();
auto ctx = (__bridge MNNMetalContext *)context(); auto ctx = (__bridge MNNMetalContext *)context();
if(!mFrameEncodeCache) { if(!mFrameEncodeCache) {
[ctx commit_net]; commit_net();
} }
onCopyBuffer(src, dst, nil, nil); onCopyBuffer(src, dst, nil, nil);
} }
id<MTLComputeCommandEncoder> MetalBackend::encoder() const { id<MTLComputeCommandEncoder> MetalBackend::encoder_for_net() const {
if (nil == mComputeEncoder) { if (nil == mComputeEncoder) {
auto ctx = (__bridge MNNMetalContext *)context(); mComputeEncoder = encoder_net();//TO DO :: use which cmdBuffer
mComputeEncoder = [ctx encoder_net];//TO DO :: use which cmdBuffer
} }
return mComputeEncoder; return mComputeEncoder;
} }
@ -570,14 +623,99 @@ void MetalBackend::onCopyBuffer(const Tensor *src, const Tensor *dst, id<MTLComp
int MetalBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) { int MetalBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) {
flushEncoder(); flushEncoder();
auto ctx = (__bridge MNNMetalContext *)context(); auto ctx = (__bridge MNNMetalContext *)context();
[ctx commit_net]; commit_net();
if (toCpu) { if (toCpu) {
[ctx wait]; wait();
} }
mFrameEncodeCache = false; mFrameEncodeCache = false;
mOpEncoderSet = false; mOpEncoderSet = false;
return 0; return 0;
} }
id<MTLCommandBuffer> MetalBackend::getCommandBufferForBufferCopy() const {
if (nil == _commandBuffer) {
_commandBuffer = [_commandQueue commandBuffer];
if (!mSupportDeferEncode) {
// In this case _commandBuffer should be the same as _commandBuffer_net
_commandBuffer_net = _commandBuffer;
}
}
return _commandBuffer;
}
id<MTLCommandBuffer> MetalBackend::getCommandBufferForNet() const {
if (nil == _commandBuffer_net) {
_commandBuffer_net = [_commandQueue commandBuffer];
if (!mSupportDeferEncode) {
// In this case _commandBuffer should be the same as _commandBuffer_net
_commandBuffer = _commandBuffer_net;
}
}
return _commandBuffer_net;
}
void MetalBackend::commit() const {
if (nil != _commandBuffer && _commandBuffer.status < MTLCommandBufferStatusCommitted) {
[_commandBuffer commit];
_waiting = _commandBuffer;
_commandBuffer = nil;
if (!mSupportDeferEncode) {
// In this case _commandBuffer should be the same as _commandBuffer_net
_commandBuffer_net = nil;
}
}
}
void MetalBackend::commit_net() const {
if (nil != _commandBuffer_net && _commandBuffer_net.status < MTLCommandBufferStatusCommitted) {
[_commandBuffer_net commit];
_waiting = _commandBuffer_net;
_commandBuffer_net = nil;
if (!mSupportDeferEncode) {
// In this case _commandBuffer should be the same as _commandBuffer_net
_commandBuffer = nil;
}
}
}
void MetalBackend::wait() const {
if (nil != _waiting) {
auto buffer = _waiting;
if (buffer.status >= MTLCommandBufferStatusCompleted) {
return;
}
#if MNN_METAL_BENCHMARK
NSTimeInterval begin = [NSDate timeIntervalSinceReferenceDate];
[buffer waitUntilCompleted];
NSTimeInterval end = [NSDate timeIntervalSinceReferenceDate];
if (@available(iOS 10.3, *)) {
printf("[METAL] commit costs: %.3fms\t(kernel: %.3fms, GPU: %.3fms)\n", (end - begin) * 1000.f,
(buffer.kernelEndTime - buffer.kernelStartTime) * 1000.f,
(buffer.GPUEndTime - buffer.GPUStartTime) * 1000.f);
} else {
printf("[METAL] commit costs: %.3fms\n", (end - begin) * 1000.f);
}
#else
[buffer waitUntilCompleted];
#endif
#if MNN_METAL_DEBUG
if (buffer.error) {
printf("[METAL] %s\n", buffer.error.localizedDescription.UTF8String);
}
#endif
}
_waiting = nil;
}
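The backend now keeps two lazily created command buffers: _commandBuffer for host/device copies and _commandBuffer_net for encoded network work, collapsed into one buffer when deferred encoding is unsupported. commit()/commit_net() submit whichever buffer is open and remember it in _waiting, and wait() blocks until that buffer completes. A sketch of the download sequence the copy paths above run, assuming these members are reachable from the call site and `deviceBuffer` is a shared-storage MTLBuffer:

#include <cstring>
#import <Metal/Metal.h>
#import "backend/metal/MetalBackend.hpp"

static void downloadSketch(const MNN::MetalBackend* backend, id<MTLBuffer> deviceBuffer,
                           void* hostDst, size_t bytes) {
    backend->commit_net(); // flush any network command buffer that is still open
    backend->commit();     // then the copy command buffer, if one was created
    backend->wait();       // block until the last committed buffer completes
    ::memcpy(hostDst, deviceBuffer.contents, bytes);
}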
id<MTLComputePipelineState> MetalBackend::makeComputePipelineWithSourceOption(const char* csource, const char* cname, MTLCompileOptions *options) const{
auto ctx = (__bridge MNNMetalContext *)context();
auto source = [[NSString alloc] initWithUTF8String:csource];
auto name = [[NSString alloc] initWithUTF8String:cname];
return [ctx pipelineWithSourceOption:source name:name options:options];
}
void MetalRuntime::setCommandQueue(id<MTLCommandQueue> queue) {
mQueue = queue;
}
void MetalRuntime::setGpuMode(const int mode_num) { void MetalRuntime::setGpuMode(const int mode_num) {
int totalSet = 0; int totalSet = 0;
@ -642,9 +780,6 @@ MetalRuntime* MetalRuntime::create(const Backend::Info& info, id<MTLDevice> devi
if (nil == sharedContext.device) { if (nil == sharedContext.device) {
sharedContext.device = device; sharedContext.device = device;
} }
if (nil == sharedContext.queue) {
sharedContext.queue = [sharedContext.device newCommandQueue];
}
auto mContext = (__bridge_retained void *)[[MNNMetalContext alloc] init]; auto mContext = (__bridge_retained void *)[[MNNMetalContext alloc] init];
auto ctx = (__bridge MNNMetalContext *)mContext; auto ctx = (__bridge MNNMetalContext *)mContext;
BOOL res = [ctx initWithSharedContext:&sharedContext dev:device]; BOOL res = [ctx initWithSharedContext:&sharedContext dev:device];
@ -654,6 +789,18 @@ MetalRuntime* MetalRuntime::create(const Backend::Info& info, id<MTLDevice> devi
} }
auto rt = new MetalRuntime(mContext); auto rt = new MetalRuntime(mContext);
rt->setGpuMode(info.gpuMode); rt->setGpuMode(info.gpuMode);
if (nil != sharedContext.queue) {
rt->setCommandQueue(sharedContext.queue);
}
#ifdef MNN_METAL_TEST
else {
id<MTLCommandQueue> queue = [sharedContext.device newCommandQueue];
rt->setCommandQueue(queue);
}
#endif
if (nullptr != info.user) {
rt->mDefaultConfig = *info.user;
}
return rt; return rt;
} }
@ -833,7 +980,12 @@ bool MetalRuntime::onMeasure(const std::vector<Tensor*>& inputs, const std::vect
} }
Backend* MetalRuntime::onCreate(const BackendConfig* config) const { Backend* MetalRuntime::onCreate(const BackendConfig* config) const {
return new MetalBackend(mStatic, this); BackendConfig::PrecisionMode precision = mDefaultConfig.precision;
if (nullptr != config) {
precision = config->precision;
}
bool useFp16AsFp32 = precision != BackendConfig::Precision_High;
return new MetalBackend(mStatic, this, useFp16AsFp32);
} }
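With this change the Metal backend defaults to fp16 storage unless the caller asks for Precision_High, so sessions that need full fp32 accuracy must request it explicitly. A typical setup through the public Interpreter API, shown only to illustrate how the new precision flag is reached ("model.mnn" is a placeholder path):

#include <memory>
#include <MNN/Interpreter.hpp>
#include <MNN/MNNForwardType.h>

int main() {
    std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile("model.mnn"));
    MNN::ScheduleConfig schedule;
    schedule.type = MNN_FORWARD_METAL;
    MNN::BackendConfig backendConfig;
    // Precision_High makes useFp16AsFp32 false above, keeping fp32 storage on the GPU.
    backendConfig.precision = MNN::BackendConfig::Precision_High;
    schedule.backendConfig = &backendConfig;
    auto session = net->createSession(schedule);
    return 0;
}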
void MetalRuntime::onGabageCollect(int level) { void MetalRuntime::onGabageCollect(int level) {
@ -895,6 +1047,9 @@ void registerMetalRuntimeCreator() {
id<MTLDevice> device = MTLCreateSystemDefaultDevice(); id<MTLDevice> device = MTLCreateSystemDefaultDevice();
if (nil != device) { if (nil != device) {
registerMetalOps(); registerMetalOps();
#ifdef MNN_SUPPORT_RENDER
registerMetalRenderOps();
#endif
MNNInsertExtraRuntimeCreator(MNN_FORWARD_METAL, new MetalRuntimeCreator(device), false); MNNInsertExtraRuntimeCreator(MNN_FORWARD_METAL, new MetalRuntimeCreator(device), false);
} else { } else {
MNN_ERROR("Init Metal Error\n"); MNN_ERROR("Init Metal Error\n");
View File
@ -9,17 +9,16 @@
#ifndef MetalBinary_hpp #ifndef MetalBinary_hpp
#define MetalBinary_hpp #define MetalBinary_hpp
#import "core/Execution.hpp" #import "MetalExecution.hpp"
#import "MetalDefine.h"
#include <string> #include <string>
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
class MetalBinary : public Execution { class MetalBinary : public MetalExecution {
public: public:
MetalBinary(Backend *backend, std::string type, const MNN::Op *op); MetalBinary(Backend *backend, std::string type, const MNN::Op *op);
virtual ~MetalBinary() = default; virtual ~MetalBinary() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private: private:
View File
@ -14,13 +14,13 @@
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
MetalBinary::MetalBinary(Backend *backend, std::string type, const MNN::Op *op) : Execution(backend) { MetalBinary::MetalBinary(Backend *backend, std::string type, const MNN::Op *op) : MetalExecution(backend) {
auto mKernelName = "binary_" + type + "_x1"; auto mKernelName = "binary_" + type + "_x1";
auto mtbn = static_cast<MetalBackend *>(backend); auto mtbn = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context(); auto context = (__bridge MNNMetalContext *)mtbn->context();
mConstBuffer = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly]; mConstBuffer = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
auto kn = [NSString stringWithCString:mKernelName.c_str() encoding:[NSString defaultCStringEncoding]]; auto kn = [NSString stringWithCString:mKernelName.c_str() encoding:[NSString defaultCStringEncoding]];
mPipeline = [context pipelineWithName:kn]; mPipeline = [context pipelineWithName:kn fp16:mtbn->useFp16InsteadFp32()];
mActivationType = op->main_as_BinaryOp()->activationType(); mActivationType = op->main_as_BinaryOp()->activationType();
} }
ErrorCode MetalBinary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode MetalBinary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
@ -39,32 +39,14 @@ ErrorCode MetalBinary::onResize(const std::vector<Tensor *> &inputs, const std::
return NO_ERROR; return NO_ERROR;
} }
ErrorCode MetalBinary::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { void MetalBinary::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend()); auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
[encoder setComputePipelineState:mPipeline];
if(backend->isCommandEncoderSet()) { [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
return NO_ERROR; [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
} [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
auto func = [=](){ [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
} }
#define CHECK(t, i) if (originOp == t) return i; #define CHECK(t, i) if (originOp == t) return i;
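MetalBinary and the ops below now derive from MetalExecution and only implement onResize plus onEncode; the encoder bookkeeping each op previously did (isCommandEncoderSet(), addOpEncoder(), conditional commit_net) is presumably handled once by the new MetalExecution base, whose implementation is not part of this hunk. A hypothetical minimal op written against the new interface (class and member names invented for illustration):

#import <Metal/Metal.h>
#import "MetalExecution.hpp"
#import "MetalBackend.hpp"

namespace MNN {
class MetalIdentityDemo : public MetalExecution {
public:
    MetalIdentityDemo(Backend *backend, id<MTLComputePipelineState> pipeline)
        : MetalExecution(backend), mPipeline(pipeline) {}
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
        // derive mThreads from the output shape here
        return NO_ERROR;
    }
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
        [encoder setComputePipelineState:mPipeline];
        // bind input/output MTLBuffers exactly as MetalBinary::onEncode does above
        [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
    }
private:
    id<MTLComputePipelineState> mPipeline;
    std::pair<MTLSize, MTLSize> mThreads;
};
} // namespace MNN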
View File
@ -9,22 +9,23 @@
#ifndef MetalCast_hpp #ifndef MetalCast_hpp
#define MetalCast_hpp #define MetalCast_hpp
#import "core/Execution.hpp" #import "MetalExecution.hpp"
#import "MetalDefine.h"
#import "Type_generated.h" #import "Type_generated.h"
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
class MetalCast : public Execution { class MetalCast : public MetalExecution {
public: public:
MetalCast(Backend *backend, DataType srcType, DataType dstType); MetalCast(Backend *backend, id<MTLComputePipelineState> pipeline);
virtual ~MetalCast() = default; virtual ~MetalCast() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private: private:
DataType mSrcType; id<MTLBuffer> mConstBuffer;
DataType mDstType; id<MTLComputePipelineState> mPipeline;
std::pair<MTLSize, MTLSize> mThreads;
}; };
} // namespace MNN } // namespace MNN
View File
@ -13,40 +13,55 @@
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
static const char* gCastTemplate =
R"glsl(
#include <metal_stdlib>
using namespace metal;
kernel void main0(const device T0 *in [[buffer(0)]],
device T1 *out [[buffer(1)]],
device uint4& s [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x < (uint)s.x) {
int off = gid.x;
T0 x = in[off];
T1 y;
y.x = x.x;
y.y = x.y;
y.z = x.z;
y.w = x.w;
TRANSOFRM;
out[off] = y;
}
}
)glsl";
MetalCast::MetalCast(Backend *backend, DataType srcType, DataType dstType) MetalCast::MetalCast(Backend *backend, id<MTLComputePipelineState> pipeline)
: Execution(backend), mSrcType(srcType), mDstType(dstType) { : MetalExecution(backend) {
// nothing to do auto mtbn = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context();
mPipeline = pipeline;
mConstBuffer = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
}
ErrorCode MetalCast::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto mtbn = static_cast<MetalBackend *>(backend());
auto context = (__bridge MNNMetalContext *)mtbn->context();
auto input = inputs[0];
auto element = input->elementSize();
auto sizeDiv4 = UP_DIV(element, 4);
((int *)mConstBuffer.contents)[0] = sizeDiv4;
mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(sizeDiv4, 1, 1)];
return NO_ERROR;
} }
ErrorCode MetalCast::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { void MetalCast::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend()); auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context(); auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0], output = outputs[0]; auto input = inputs[0], output = outputs[0];
[encoder setComputePipelineState:mPipeline];
NSString *kernel = nil;
if (mSrcType == DataType_DT_FLOAT && mDstType == DataType_DT_INT32) {
kernel = @"cast_float_to_int32";
} else if (mSrcType == DataType_DT_INT32 && mDstType == DataType_DT_FLOAT) {
kernel = @"cast_int32_to_float";
} else if (mSrcType == DataType_DT_UINT8 && mDstType == DataType_DT_FLOAT) {
kernel = @"cast_uint8_to_float";
} else if (mSrcType == DataType_DT_UINT8 && mDstType == DataType_DT_INT32) {
kernel = @"cast_uint8_to_int";
} else if (mSrcType == DataType_DT_FLOAT && mDstType == DataType_DT_UINT8) {
kernel = @"cast_float_to_uint8";
} else {
return NOT_SUPPORT;
}
auto encoder = backend->encoder();
auto bandwidth = [context load:kernel encoder:encoder];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[context dispatchEncoder:encoder [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
threads:{ (NSUInteger) output->elementSize(), (NSUInteger)1, (NSUInteger)1 } [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
bandwidth:bandwidth];
return NO_ERROR;
} }
static DataType _mapDataType(DataType src) { static DataType _mapDataType(DataType src) {
if (DataType_DT_BOOL == src) { if (DataType_DT_BOOL == src) {
@ -63,27 +78,88 @@ static DataType _mapDataType(DataType src) {
class MetalCastCreator : public MetalBackend::Creator { class MetalCastCreator : public MetalBackend::Creator {
public: public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const { virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
auto cast = op->main_as_CastParam(); auto mtbn = static_cast<MetalBackend *>(backend);
MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
NSString* T0 = nil;
NSString* T1 = nil;
NSString* TRANSOFRM = @"";
auto dstT = op->main_as_CastParam()->dstT();
if (dstT == DataType_DT_BOOL) {
TRANSOFRM = @"y=select(int4(0),int4(1),y>0);";
}
auto dstType = _mapDataType(dstT);
bool useFp16 = mtbn->useFp16InsteadFp32();
switch (dstType) {
case DataType_DT_FLOAT:
if (useFp16) {
T1 = @"half4";
} else {
T1 = @"float4";
}
break;
case DataType_DT_INT8:
T1 = @"char4";
break;
case DataType_DT_UINT8:
T1 = @"uchar4";
break;
case DataType_DT_INT32:
T1 = @"int4";
break;
default:
MNN_ERROR("Don't support cast dst : %d\n", dstType);
return nullptr;
break;
}
auto srcType = inputs[0]->getType(); auto srcType = inputs[0]->getType();
auto dst = _mapDataType(cast->dstT()); switch (srcType.code) {
case halide_type_float:
if (useFp16) {
T0 = @"half4";
} else {
T0 = @"float4";
}
break;
case halide_type_int:
{
if (srcType.bits == 32) {
T0 = @"int4";
} else if (srcType.bits == 8) {
T0 = @"char4";
} else {
MNN_ERROR("Don't support cast src : %d\n", srcType.code);
return nullptr;
}
break;
}
case halide_type_uint:
{
if (srcType.bits == 32) {
T0 = @"uint4";
} else if (srcType.bits == 8) {
T0 = @"uchar4";
} else {
MNN_ERROR("Don't support cast src : %d\n", srcType.code);
return nullptr;
}
break;
}
default:
MNN_ERROR("Don't support cast src : %d\n", srcType.code);
return nullptr;
}
if (srcType.code == halide_type_float && dst == DataType_DT_INT32) { compileOptions.preprocessorMacros = @{
return new MetalCast(backend, DataType_DT_FLOAT, dst); @"T0" : T0,
@"T1" : T1,
@"TRANSOFRM" : TRANSOFRM
};
auto pipeline = mtbn->makeComputePipelineWithSourceOption(gCastTemplate, "main0", compileOptions);
if (nil == pipeline) {
MNN_ERROR("Create Cast execution error for metal\n");
return nullptr;
} }
if (srcType.code == halide_type_int && srcType.bits == 32 && dst == DataType_DT_FLOAT) { return new MetalCast(backend, pipeline);
return new MetalCast(backend, DataType_DT_INT32, dst);
}
if (srcType.code == halide_type_float && dst == DataType_DT_UINT8) {
return new MetalCast(backend, DataType_DT_FLOAT, dst);
}
if (srcType.code == halide_type_uint && srcType.bits == 8 && dst == DataType_DT_FLOAT) {
return new MetalCast(backend, DataType_DT_UINT8, dst);
}
if (srcType.code == halide_type_uint && srcType.bits == 8 && dst == DataType_DT_INT32) {
return new MetalCast(backend, DataType_DT_UINT8, dst);
}
MNN_PRINT("%d, %d - %d\n", srcType.code, srcType.bits, dst);
return NULL;
} }
}; };
REGISTER_METAL_OP_CREATOR(MetalCastCreator, OpType_Cast); REGISTER_METAL_OP_CREATOR(MetalCastCreator, OpType_Cast);
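MetalCast no longer picks from a fixed list of cast_*_* kernels; the creator specializes gCastTemplate at build time by injecting T0/T1/TRANSOFRM as preprocessor macros and compiling through makeComputePipelineWithSourceOption. Roughly what the creator above assembles for a float-to-int32 cast on an fp16 backend (a fragment inside MetalCastCreator::onCreate, where mtbn and gCastTemplate are in scope):

MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
options.preprocessorMacros = @{
    @"T0" : @"half4",   // source elements as stored on an fp16 backend
    @"T1" : @"int4",    // destination elements
    @"TRANSOFRM" : @""  // empty: no extra transform (the macro name matches the template)
};
auto pipeline = mtbn->makeComputePipelineWithSourceOption(gCastTemplate, "main0", options);
// a nil pipeline signals an unsupported combination, as the creator handles above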
View File
@ -3,14 +3,15 @@ import sys
from os import listdir from os import listdir
from os.path import isfile, join from os.path import isfile, join
import makeshader import makeshader
shaderPath=sys.argv[1] metalSourcePath=sys.argv[1]
cppPath= shaderPath + "/MetalOPRegister.mm" renderPath = os.path.join(metalSourcePath, "render")
cppPath= os.path.join(metalSourcePath, "MetalOPRegister.mm")
cppRenderPath = os.path.join(renderPath, 'MetalRenderOpRegister.mm')
def genRegister(): def genRegister():
shaders=[] shaders=[]
for root, dirs, files in os.walk(shaderPath): for file in os.listdir(metalSourcePath):
for file in files: if file.endswith('.mm'):
if file.endswith('.mm'): shaders.append(os.path.join(metalSourcePath,file))
shaders.append(os.path.join(root,file))
with open(cppPath,"w") as f: with open(cppPath,"w") as f:
f.write("// This file is generated by Shell for ops register\n") f.write("// This file is generated by Shell for ops register\n")
f.write("#import \"backend/metal/MetalDefine.h\"\n") f.write("#import \"backend/metal/MetalDefine.h\"\n")
@ -31,19 +32,48 @@ def genRegister():
for func in funcs: for func in funcs:
f.write(" "+func+"\n") f.write(" "+func+"\n")
f.write("}\n#endif\n}") f.write("}\n#endif\n}")
if os.path.isdir(renderPath):
shaders=[]
for file in os.listdir(renderPath):
if file.endswith('.mm'):
shaders.append(os.path.join(renderPath,file))
with open(cppRenderPath,"w") as f:
f.write("// This file is generated by Shell for ops register\n")
f.write("#import \"backend/metal/MetalDefine.h\"\n")
f.write(" namespace MNN {\n")
f.write("#if MNN_METAL_ENABLED\n")
funcs=[]
for shapath in shaders:
with open(shapath,"r") as sha:
lines=sha.readlines()
for l in lines:
if l.startswith("REGISTER_METAL_OP_CREATOR("):
x=l.replace("REGISTER_METAL_OP_CREATOR(","").replace(")","").replace(" ","").replace(";","").replace("\n","").split(",")
funcname="___"+x[0]+"__"+x[1]+"__();"
funcs.append(funcname)
f.write(" extern void "+funcname+"\n")
pass
f.write("void registerMetalRenderOps() {\n")
for func in funcs:
f.write(" "+func+"\n")
f.write("}\n#endif\n}")
def genSchema(): def genSchema():
FLATC = shaderPath + "/../../../3rd_party/flatbuffers/tmp/flatc" FLATC = metalSourcePath + "/../../../3rd_party/flatbuffers/tmp/flatc"
sourceFile = shaderPath + "/schema/MetalCache.fbs" sourceFile = metalSourcePath + "/schema/MetalCache.fbs"
destFile = shaderPath + "/" destFile = metalSourcePath + "/"
cmd = FLATC + " -c " + sourceFile +" --gen-object-api" +" --reflect-names" cmd = FLATC + " -c " + sourceFile +" --gen-object-api" +" --reflect-names"
print(cmd) print(cmd)
print(os.popen(cmd).read()) print(os.popen(cmd).read())
return return
def genShader(): def genShader():
if os.path.isdir(renderPath):
print("Has Render")
shaders = makeshader.findAllShader("render/shader")
makeshader.generateFile(os.path.join(renderPath, "AllRenderShader.hpp"), os.path.join(renderPath, "AllRenderShader.cpp"), shaders)
shaders = makeshader.findAllShader("shader") shaders = makeshader.findAllShader("shader")
makeshader.generateFile("AllShader.hpp", "AllShader.cpp", shaders) makeshader.generateFile(os.path.join(metalSourcePath, "AllShader.hpp"), os.path.join(metalSourcePath, "AllShader.cpp"), shaders)
if __name__ == '__main__': if __name__ == '__main__':
genRegister() genRegister()
View File
@ -21,7 +21,7 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected: protected:
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override; virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
private: private:
std::string mParam; std::string mParam;
View File
@ -10,7 +10,6 @@
#import "core/Macro.h" #import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp" #import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution1x1.hpp" #import "backend/metal/MetalConvolution1x1.hpp"
#import "backend/metal/MetalConvolutionGEMM.hpp"
#import "backend/metal/MetalConvolutionWinograd.hpp" #import "backend/metal/MetalConvolutionWinograd.hpp"
#include <string> #include <string>
@ -26,6 +25,7 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
// prepare // prepare
auto backend = static_cast<MetalBackend *>(this->backend()); auto backend = static_cast<MetalBackend *>(this->backend());
auto mtbn = backend;
auto context = (__bridge MNNMetalContext *)backend->context(); auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0]; auto input = inputs[0];
auto output = outputs[0]; auto output = outputs[0];
@ -92,13 +92,13 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
NSUInteger gid_y = oh; NSUInteger gid_y = oh;
NSUInteger gid_z = UP_DIV(oc_4, packC) * ob; NSUInteger gid_z = UP_DIV(oc_4, packC) * ob;
mPipeline = [context pipelineWithName:kernelName]; mPipeline = [context pipelineWithName:kernelName fp16:backend->useFp16InsteadFp32()];
NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(), NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(),
(id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(), (id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(),
mConstBuffer, mWeight, mBias, nil]; mConstBuffer, mWeight, mBias, nil];
std::string name = [kernelName UTF8String] + mParam; std::string name = [kernelName UTF8String] + mParam;
auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret)); mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret));
} else { } else {
const int total_kernel = 5; const int total_kernel = 5;
@ -130,13 +130,13 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
mConstBuffer, mWeight, mBias, nil]; mConstBuffer, mWeight, mBias, nil];
for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx]]; id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx] fp16:mtbn->useFp16InsteadFp32()];
NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]); NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]);
NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]); NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]);
NSUInteger gid_z = UP_DIV(oc_4, itemC[knl_idx]) * ob; NSUInteger gid_z = UP_DIV(oc_4, itemC[knl_idx]) * ob;
std::string name = [shaderName[knl_idx] UTF8String] + mParam; std::string name = [shaderName[knl_idx] UTF8String] + mParam;
auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
if(min_cost.first > std::get<2>(ret)) { if(min_cost.first > std::get<2>(ret)) {
min_cost.first = std::get<2>(ret); min_cost.first = std::get<2>(ret);
@ -148,45 +148,24 @@ ErrorCode MetalConvolution::onResize(const std::vector<Tensor *> &inputs, const
// printf("conv idx:%d, min_cost:%d\n", (int)min_cost.second, (int)min_cost.first); // printf("conv idx:%d, min_cost:%d\n", (int)min_cost.second, (int)min_cost.first);
// std::string tmp = [shaderName[min_cost.second] UTF8String]; // std::string tmp = [shaderName[min_cost.second] UTF8String];
// printf("!!~ %s\n", tmp.c_str()); // printf("!!~ %s\n", tmp.c_str());
mPipeline = [context pipelineWithName:shaderName[min_cost.second]]; mPipeline = [context pipelineWithName:shaderName[min_cost.second] fp16:mtbn->useFp16InsteadFp32()];
} }
return NO_ERROR; return NO_ERROR;
} }
ErrorCode MetalConvolution::onFloat(const Tensor *input, const Tensor *output) { void MetalConvolution::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend()); auto oc_4 = UP_DIV(output->channel(), 4);
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCommandEncoderSet()) { auto bandwidth = (MetalBandwidth){mPipeline.threadExecutionWidth, mPipeline.maxTotalThreadsPerThreadgroup, NO};
return NO_ERROR;
}
auto func = [=](){ [encoder setComputePipelineState:mPipeline];
auto oc_4 = UP_DIV(output->channel(), 4); [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
auto encoder = backend->encoder(); [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
auto bandwidth = (MetalBandwidth){mPipeline.threadExecutionWidth, mPipeline.maxTotalThreadsPerThreadgroup, NO}; [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
//need to commit
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
} }
class MetalConvolutionCreator : public MetalBackend::Creator { class MetalConvolutionCreator : public MetalBackend::Creator {
@ -207,9 +186,6 @@ public:
if (MetalConvolutionWinograd::isValid(conv, inputs[0], outputs[0])) { if (MetalConvolutionWinograd::isValid(conv, inputs[0], outputs[0])) {
return new MetalConvolutionWinograd(backend, input, op); return new MetalConvolutionWinograd(backend, input, op);
} }
if (MetalConvolutionGEMM::isValid(conv, input)) {
return new MetalConvolutionGEMM(backend, input, op);
}
if (MetalConvolution1x1::isValid(conv, input)) { if (MetalConvolution1x1::isValid(conv, input)) {
return new MetalConvolution1x1(backend, op); return new MetalConvolution1x1(backend, op);
} }
View File
@ -22,7 +22,7 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected: protected:
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override; virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
private: private:
id<MTLComputePipelineState> mPipeline; id<MTLComputePipelineState> mPipeline;
std::pair<MTLSize, MTLSize> mThreads; std::pair<MTLSize, MTLSize> mThreads;
View File
@ -57,7 +57,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
NSUInteger gid_y = oc_4; NSUInteger gid_y = oc_4;
NSUInteger gid_z = ob; NSUInteger gid_z = ob;
mPipeline = [context pipelineWithName:@"conv1x1_g1z8"]; mPipeline = [context pipelineWithName:@"conv1x1_g1z8" fp16:backend->useFp16InsteadFp32()];
NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(), NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(),
(id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(), (id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(),
@ -65,14 +65,14 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
std::string name = "conv1x1_g1z8"; std::string name = "conv1x1_g1z8";
MetalRuntime *rt = (MetalRuntime *)backend->runtime(); MetalRuntime *rt = (MetalRuntime *)backend->runtime();
auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret)); mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret));
} else { } else {
NSUInteger gid_x = UP_DIV(ow * oh, 4); NSUInteger gid_x = UP_DIV(ow * oh, 4);
NSUInteger gid_y = oc_4; NSUInteger gid_y = oc_4;
NSUInteger gid_z = ob; NSUInteger gid_z = ob;
mPipeline = [context pipelineWithName:@"conv1x1_g1z4"]; mPipeline = [context pipelineWithName:@"conv1x1_g1z4" fp16:backend->useFp16InsteadFp32()];
NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(), NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer(),
(id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(), (id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId()))->getBuffer(),
@ -80,7 +80,7 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
std::string name = "conv1x1_g1z4"; std::string name = "conv1x1_g1z4";
MetalRuntime *rt = (MetalRuntime *)backend->runtime(); MetalRuntime *rt = (MetalRuntime *)backend->runtime();
auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret)); mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret));
//printf("conv1x1_z4, %d %d %d %d\n", ow, oh, oc_4, ic_4); //printf("conv1x1_z4, %d %d %d %d\n", ow, oh, oc_4, ic_4);
} }
@ -100,13 +100,13 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
mConstBuffer, mWeight, mBias, nil]; mConstBuffer, mWeight, mBias, nil];
for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx]]; id<MTLComputePipelineState> pipeline = [context pipelineWithName:shaderName[knl_idx] fp16:backend->useFp16InsteadFp32()];
NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]); NSUInteger gid_x = UP_DIV(ow, itemW[knl_idx]);
NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]); NSUInteger gid_y = UP_DIV(oh, itemH[knl_idx]);
NSUInteger gid_z = ob * UP_DIV(oc, itemC[knl_idx]); NSUInteger gid_z = ob * UP_DIV(oc, itemC[knl_idx]);
std::string name = [shaderName[knl_idx] UTF8String]; std::string name = [shaderName[knl_idx] UTF8String];
auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:pipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
if(min_cost.first > std::get<2>(ret)) { if(min_cost.first > std::get<2>(ret)) {
min_cost.first = std::get<2>(ret); min_cost.first = std::get<2>(ret);
@ -116,39 +116,20 @@ ErrorCode MetalConvolution1x1::onResize(const std::vector<Tensor *> &inputs, con
//printf("conv1x1 idx:%d, global:%d %d %d, local:%d %d %d, min_cost:%d\n", knl_idx, (int)retTune.second.first.width, (int)retTune.second.first.height, (int)retTune.second.first.depth, (int)retTune.second.second.width, (int)retTune.second.second.height, (int)retTune.second.second.depth, (int)retTune.first); //printf("conv1x1 idx:%d, global:%d %d %d, local:%d %d %d, min_cost:%d\n", knl_idx, (int)retTune.second.first.width, (int)retTune.second.first.height, (int)retTune.second.first.depth, (int)retTune.second.second.width, (int)retTune.second.second.height, (int)retTune.second.second.depth, (int)retTune.first);
} }
//printf("conv1x1 idx:%d, min_cost:%d\n", (int)min_cost.second, (int)min_cost.first); //printf("conv1x1 idx:%d, min_cost:%d\n", (int)min_cost.second, (int)min_cost.first);
mPipeline = [context pipelineWithName:shaderName[min_cost.second]]; mPipeline = [context pipelineWithName:shaderName[min_cost.second] fp16:backend->useFp16InsteadFp32()];
} }
return NO_ERROR; return NO_ERROR;
} }
ErrorCode MetalConvolution1x1::onFloat(const Tensor *input, const Tensor *output) { void MetalConvolution1x1::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend()); [encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
if(backend->isCommandEncoderSet()) { [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
return NO_ERROR; [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
} [encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
auto func = [=](){ [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
} }
} // namespace MNN } // namespace MNN
#endif /* MNN_METAL_ENABLED */ #endif /* MNN_METAL_ENABLED */
View File
@ -11,21 +11,22 @@
#import "core/ConvolutionCommon.hpp" #import "core/ConvolutionCommon.hpp"
#import "MetalBackend.hpp" #import "MetalBackend.hpp"
#import "MetalExecution.hpp"
#import "MNNMetalContext.h" #import "MNNMetalContext.h"
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
class MetalConvolutionCommon : public Execution { class MetalConvolutionCommon : public MetalExecution {
public: public:
MetalConvolutionCommon(Backend *backend, const MNN::Op *op); MetalConvolutionCommon(Backend *backend, const MNN::Op *op);
virtual ~MetalConvolutionCommon() = default; virtual ~MetalConvolutionCommon() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
protected: protected:
void loadWeight(const MNN::Convolution2D *conv); void loadWeight(const MNN::Convolution2D *conv);
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) = 0; virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) = 0;
virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src); virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src);
private: private:
View File
@ -16,23 +16,32 @@
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
static id<MTLBuffer> biasForConv(MNNMetalContext *context, const Convolution2D *conv) { static id<MTLBuffer> biasForConv(MNNMetalContext *context, const Convolution2D *conv, bool fp16) {
auto bias = conv->bias(); auto bias = conv->bias();
auto oc = conv->common()->outputCount(); auto oc = conv->common()->outputCount();
auto bias_size = UP_DIV(oc, 16) * 16 * sizeof(metal_float); int bytes = 4;
if (fp16) {
bytes = 2;
}
auto bias_size = UP_DIV(oc, 16) * 16 *bytes;
auto buffer = [context newDeviceBuffer:bias_size access:CPUWriteOnly]; auto buffer = [context newDeviceBuffer:bias_size access:CPUWriteOnly];
auto src = bias->data(); auto src = bias->data();
auto dst = (metal_float *)buffer.contents; ::memset(buffer.contents, 0, bias_size);
::memset(dst, 0, bias_size); if (fp16) {
#pragma clang loop vectorize(enable) unroll(enable) auto dst = (__fp16 *)buffer.contents;
for (int i = 0; i < oc; i++) { #pragma clang loop vectorize(enable) unroll(enable)
dst[i] = src[i]; for (int i = 0; i < oc; i++) {
dst[i] = src[i];
}
} else {
::memcpy(buffer.contents, src, oc * sizeof(float));
} }
return buffer; return buffer;
} }
MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *op) : Execution(backend) { MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *op) : MetalExecution(backend) {
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context(); auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
auto mtbn = static_cast<MetalBackend*>(backend);
auto conv = op->main_as_Convolution2D(); auto conv = op->main_as_Convolution2D();
auto common = conv->common(); auto common = conv->common();
mOp = op; mOp = op;
@ -47,7 +56,7 @@ MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *
mStrideY = common->strideY(); mStrideY = common->strideY();
mDilateX = common->dilateX(); mDilateX = common->dilateX();
mDilateY = common->dilateY(); mDilateY = common->dilateY();
mBias = biasForConv(context, conv); mBias = biasForConv(context, conv, mtbn->useFp16InsteadFp32());
mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0); mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);
} }
@ -55,8 +64,8 @@ ErrorCode MetalConvolutionCommon::onResize(const std::vector<Tensor *> &inputs,
return NO_ERROR; return NO_ERROR;
} }
ErrorCode MetalConvolutionCommon::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { void MetalConvolutionCommon::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
return onFloat(inputs[0], outputs[0]); return onFloat(inputs[0], outputs[0], encoder);
} }
template <typename FType, typename TType> template <typename FType, typename TType>
@ -103,8 +112,10 @@ void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv) {
id<MTLBuffer> MetalConvolutionCommon::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) { id<MTLBuffer> MetalConvolutionCommon::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
auto backend = static_cast<MetalBackend *>(this->backend()); auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context(); auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
return weightInBlock<float, metal_float>(context, group, oc, ic, kh, kw, src); if (backend->useFp16InsteadFp32()) {
return weightInBlock<float, __fp16>(context, group, oc, ic, kh, kw, src);
}
return weightInBlock<float, float>(context, group, oc, ic, kh, kw, src);
} }
id<MTLBuffer> MetalConvolutionCommon::weightForConv(const Convolution2D *conv, ConvolutionCommon::Int8Common *qnt, id<MTLBuffer> MetalConvolutionCommon::weightForConv(const Convolution2D *conv, ConvolutionCommon::Int8Common *qnt,
View File
@ -21,7 +21,7 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected: protected:
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override; virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override; virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override;
private: private:
id<MTLComputePipelineState> mPipeline; id<MTLComputePipelineState> mPipeline;
View File
@ -62,7 +62,7 @@ ErrorCode MetalConvolutionDepthwise::onResize(const std::vector<Tensor *> &input
::memcpy(mConstBuffer.contents, constants, sizeof(constants)); ::memcpy(mConstBuffer.contents, constants, sizeof(constants));
auto context = (__bridge MNNMetalContext *)backend->context(); auto context = (__bridge MNNMetalContext *)backend->context();
mPipeline = [context pipelineWithName:@"conv_depthwise"]; mPipeline = [context pipelineWithName:@"conv_depthwise" fp16:backend->useFp16InsteadFp32()];
NSUInteger gid_x = ow; NSUInteger gid_x = ow;
NSUInteger gid_y = oh; NSUInteger gid_y = oh;
@ -74,39 +74,21 @@ ErrorCode MetalConvolutionDepthwise::onResize(const std::vector<Tensor *> &input
std::string name = "conv_depthwise"; std::string name = "conv_depthwise";
MetalRuntime *rt = (MetalRuntime *)backend->runtime(); MetalRuntime *rt = (MetalRuntime *)backend->runtime();
auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name]; auto ret = [context getGridAndThreadgroup:mPipeline gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name queue:backend->queue()];
mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret)); mThreads = std::make_pair(std::get<0>(ret), std::get<1>(ret));
return NO_ERROR; return NO_ERROR;
} }
ErrorCode MetalConvolutionDepthwise::onFloat(const Tensor *input, const Tensor *output) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset: TensorUtils::getDescribe(input)->extra.offset
atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset: TensorUtils::getDescribe(output)->extra.offset
atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
void MetalConvolutionDepthwise::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset: TensorUtils::getDescribe(input)->extra.offset
atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset: TensorUtils::getDescribe(output)->extra.offset
atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
template <typename FType, typename TType>
@ -132,7 +114,10 @@ static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int kh,
id<MTLBuffer> MetalConvolutionDepthwise::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
return weightInBlock<float, metal_float>(context, group, kh, kw, src);
if (backend->useFp16InsteadFp32()) {
return weightInBlock<float, __fp16>(context, group, kh, kw, src);
}
return weightInBlock<float, float>(context, group, kh, kw, src);
}
class MetalConvolutionDepthwiseCreator : public MetalBackend::Creator {

View File

@ -1,43 +0,0 @@
//
// MetalConvolutionGEMM.hpp
// MNN
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MetalConvolutionGEMM_hpp
#define MetalConvolutionGEMM_hpp
#import "MetalConvolutionCommon.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalConvolutionGEMM : public MetalConvolutionCommon {
public:
static bool isValid(const Convolution2D *conv, const Tensor *input);
MetalConvolutionGEMM(Backend *backend, const Tensor *input, const MNN::Op *op);
virtual ~MetalConvolutionGEMM() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected:
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override;
virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override;
private:
id<MTLBuffer> mShapeBuffer = nil;
std::shared_ptr<Tensor> mTempInput;
std::shared_ptr<Tensor> mTempOutput;
id<MTLComputePipelineState> mPipelineGEMM;
std::pair<MTLSize, MTLSize> mGemm;
id<MTLComputePipelineState> mPipelineIm2Col;
std::pair<MTLSize, MTLSize> mIm2Col;
id<MTLComputePipelineState> mPipelineCol2Im;
std::pair<MTLSize, MTLSize> mCol2Im;
};
} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalConvolutionGEMM_hpp */

View File

@ -1,214 +0,0 @@
//
// MetalConvolutionGEMM.mm
// MNN
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalConvolutionGEMM.hpp"
#import "core/Macro.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
bool MetalConvolutionGEMM::isValid(const Convolution2D *conv, const Tensor *input) {
auto common = conv->common();
auto kx = common->kernelX(), ky = common->kernelY();
if (kx == 1 || ky == 1 || common->group() != 1) {
return false;
}
auto oc = common->outputCount();
if (oc <= 16) {
return false;
}
auto iw = input->width(), ih = input->height(), ic = input->channel();
if (iw * ih * ic <= 16384) {
return false;
}
auto sx = common->strideX(), ow = (iw - kx + 1) / sx;
auto sy = common->strideY(), oh = (ih - ky + 1) / sy;
if ((iw * ih * ic) / (ow * oh * oc) > 4) {
return false;
}
auto unit = conv->quanParameter() != nullptr ? sizeof(float) : sizeof(metal_float);
auto iz = UP_DIV(ic, 4), oz = UP_DIV(oc, 4), batch = input->batch();
return UP_DIV(ow * oh * batch, 4) * kx * ky * iz * 16 * sizeof(metal_float) < (2 << 20) && // tmp input
UP_DIV(ow * oh * batch, 4) * oz * 16 * unit < (2 << 20); // tmp output
}
MetalConvolutionGEMM::MetalConvolutionGEMM(Backend *backend, const Tensor *input, const MNN::Op *op)
: MetalConvolutionCommon(backend, op) {
loadWeight(op->main_as_Convolution2D());
}
ErrorCode MetalConvolutionGEMM::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
// prepare
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0], output = outputs[0];
auto iw = input->width();
auto ih = input->height();
auto ic_4 = UP_DIV(input->channel(), 4);
auto ow = output->width();
auto oh = output->height();
auto oc_4 = UP_DIV(output->channel(), 4);
auto ob = output->batch();
auto pads = ConvolutionCommon::convolutionPad(input, output, mOp->main_as_Convolution2D()->common());
auto padX = pads.first;
auto padY = pads.second;
// create const buffer
int constants[] = {iw,
ih,
iw * ih,
ic_4,
ow,
oh,
ow * oh,
oc_4,
ob,
mKernelX,
mKernelY,
mKernelX * mKernelY,
mStrideX,
mStrideY,
padX,
padY,
mDilateX,
mDilateY,
mActivationType};
mConstBuffer = backend->getConstBuffer(sizeof(constants));
::memcpy(mConstBuffer.contents, constants, sizeof(constants));
// create mat mul const buffer
int shapes[] = {UP_DIV(ow * oh * ob, 4), oc_4, mKernelX * mKernelY * ic_4, 1};
mShapeBuffer = [context newDeviceBuffer:sizeof(shapes) bytes:shapes access:CPUWriteOnly];
// accquire space for source & dst
int is = UP_DIV(ow * oh * ob, 4) * mKernelX * mKernelY * ic_4 * 16 * sizeof(metal_float) / sizeof(uint8_t);
int os = UP_DIV(ow * oh * ob, 4) * oc_4 * 16 * sizeof(metal_float) / sizeof(uint8_t);
mTempInput.reset(Tensor::createDevice<uint8_t>(std::vector<int>{is}));
mTempOutput.reset(Tensor::createDevice<uint8_t>(std::vector<int>{os}));
if (!backend->onAcquireBuffer(mTempInput.get(), Backend::DYNAMIC) ||
!backend->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC)) {
return OUT_OF_MEMORY;
}
backend->onReleaseBuffer(mTempInput.get(), Backend::DYNAMIC);
backend->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC);
mPipelineGEMM = [context pipelineWithName:@"matmul4x4"];
mPipelineIm2Col = [context pipelineWithName:@"conv_im2col"];
mPipelineCol2Im = [context pipelineWithName:@"conv_col2im"];
NSUInteger gw = UP_DIV(output->width() * output->height() * output->batch(), 4);
NSUInteger gh = UP_DIV(output->channel(), 4);
{
NSUInteger gid_x = gw;
NSUInteger gid_y = gh;
NSUInteger gid_z = 1;
NSArray *arr = [NSArray arrayWithObjects:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempInput->deviceId())->getBuffer(),
(id<MTLBuffer>)(((MetalRuntimeAllocator::MetalBufferAlloc *)mTempOutput->deviceId()))->getBuffer(), mWeight, mShapeBuffer, nil];
std::string name = "matmul4x4";
MetalRuntime *rt = (MetalRuntime *)backend->runtime();
auto ret = [context getGridAndThreadgroup:mPipelineGEMM gid:MTLSizeMake(gid_x, gid_y, gid_z) loop:10 buffer:arr runtime:rt shaderName:name];
mGemm = std::make_pair(std::get<0>(ret), std::get<1>(ret));
}
mIm2Col = [context computeBestGroupAndLocal:mPipelineIm2Col threads:{(NSUInteger)ow, (NSUInteger)oh, (NSUInteger)ic_4*ob}];
mCol2Im = [context computeBestGroupAndLocal:mPipelineCol2Im threads:{(NSUInteger)ow, (NSUInteger)oh, (NSUInteger)oc_4*ob}];
return NO_ERROR;
}
ErrorCode MetalConvolutionGEMM::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
return onFloat(inputs[0], outputs[0]);
}
ErrorCode MetalConvolutionGEMM::onFloat(const Tensor *input, const Tensor *output) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto encoder = backend->encoder();
{ // im2col
[encoder setComputePipelineState:mPipelineIm2Col];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempInput->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempInput.get())->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder dispatchThreadgroups:mIm2Col.first threadsPerThreadgroup:mIm2Col.second];
}
{ // gemm
[encoder setComputePipelineState:mPipelineGEMM];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempInput->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempInput.get())->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempOutput->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempOutput.get())->extra.offset atIndex:1];
[encoder setBuffer:mWeight offset:0 atIndex:2];
[encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
[encoder dispatchThreadgroups:mGemm.first threadsPerThreadgroup:mGemm.second];
}
{ // col2im
[encoder setComputePipelineState:mPipelineCol2Im];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempOutput->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempOutput.get())->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mBias offset:0 atIndex:2];
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[encoder dispatchThreadgroups:mCol2Im.first threadsPerThreadgroup:mCol2Im.second];
}
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
template <typename FType, typename TType>
static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int oc, int ic, int kh, int kw,
const FType *src) {
auto oz = UP_DIV(oc, 4);
auto iz = UP_DIV(ic, 4);
auto buffer = [context newDeviceBuffer:oz * iz * kw * kh * 16 * sizeof(TType) access:CPUWriteOnly];
auto dst = (TType *)buffer.contents;
for (int o = 0; o < oc; o++) {
auto zo = o / 4, ro = o % 4;
auto o_dst = dst + zo * iz * kh * kw * 16 + ro; // o/4 x 4
#pragma clang loop vectorize(enable)
for (int i = 0; i < ic; i++) {
auto zi = i / 4, ri = i % 4;
auto i_dst = o_dst + zi * kh * kw * 16 + ri * 4; // i/4 x 4
#pragma clang loop vectorize(enable)
for (int h = 0; h < kh; h++) {
#pragma clang loop vectorize(enable) unroll(enable)
for (int w = 0; w < kw; w++) {
// to [g][o/4][i/4][h][w][16]
// from [g][o][i][h][w]
i_dst[(h * kw + w) * 16] = *src++;
}
}
}
}
return buffer;
}
id<MTLBuffer> MetalConvolutionGEMM::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
return weightInBlock<float, metal_float>(context, group, oc, ic, kh, kw, src);
}
} // namespace MNN
#endif /* MNN_METAL_ENABLED */

View File

@ -22,7 +22,7 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
protected:
virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override;
virtual void onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) override;
virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override;
private:

View File

@ -110,10 +110,11 @@ ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs
mOutputTransformThreads.width = uw;
mOutputTransformThreads.height = uh;
mOutputTransformThreads.depth = oz;
int bytes = backend->useFp16InsteadFp32() ? 2 : 4;
// accquire space
int is = mSrcUnit * mSrcUnit * us * iz * 16 * sizeof(metal_float) / sizeof(uint8_t);
int is = mSrcUnit * mSrcUnit * us * iz * 16 * bytes;
int os = mSrcUnit * mSrcUnit * us * oz * 16 * sizeof(metal_float) / sizeof(uint8_t);
int os = mSrcUnit * mSrcUnit * us * oz * 16 * bytes;
mTempSrc.reset(Tensor::createDevice<uint8_t>(std::vector<int>{is}));
mTempDst.reset(Tensor::createDevice<uint8_t>(std::vector<int>{os}));
backend->onAcquireBuffer(mTempSrc.get(), Backend::DYNAMIC);
@ -124,51 +125,33 @@ ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs
return NO_ERROR;
}
ErrorCode MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output) {
void MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
}
{ // gemm
auto bandwidth = [context load:@"matmul4x4" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:1];
[encoder setBuffer:mWeight offset:0 atIndex:2];
[encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
}
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder fp16:backend->useFp16InsteadFp32()];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:0];
[encoder setBuffer:mBias offset:0 atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
}
auto func = [=](){
auto encoder = backend->encoder();
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
}
{ // gemm
auto bandwidth = [context load:@"matmul4x4" encoder:encoder];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempSrc->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempSrc.get())->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:1];
[encoder setBuffer:mWeight offset:0 atIndex:2];
[encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
}
{ // transform
auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)mTempDst->deviceId())->getBuffer() offset:TensorUtils::getDescribe(mTempDst.get())->extra.offset atIndex:0];
[encoder setBuffer:mBias offset:0 atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
[encoder setBuffer:mConstBuffer offset:0 atIndex:3];
[context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
}
MNN_PRINT_ENCODER(context, encoder);
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
id<MTLBuffer> MetalConvolutionWinograd::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
auto backend = static_cast<MetalBackend *>(this->backend());
@ -179,18 +162,23 @@ id<MTLBuffer> MetalConvolutionWinograd::weightForFloat(int group, int oc, int ic
std::shared_ptr<Tensor> dstWeight = generater.allocTransformWeight(srcWeight.get(), 4, 4);
generater.transformWeight(dstWeight.get(), srcWeight.get());
#if MNN_METAL_FULL_PRECISION
auto bytes = dstWeight->host<metal_float>();
#else
std::shared_ptr<Tensor> dstWeightHalf(Tensor::create<int16_t>(dstWeight->shape()));
auto f32 = dstWeight->host<float>();
auto f16 = dstWeightHalf->host<metal_float>();
for (int i = 0; i < dstWeight->elementSize(); ++i) {
f16[i] = f32[i];
}
auto bytes = dstWeightHalf->host<metal_float>();
#endif
return [context newDeviceBuffer:4 * UP_DIV(ic, 4) * UP_DIV(oc, 4) * mSrcUnit * mSrcUnit * 4 * sizeof(metal_float)
int bytenumber = 4;
void* bytes = nullptr;
std::shared_ptr<Tensor> dstWeightHalf;
if (backend->useFp16InsteadFp32()) {
dstWeightHalf.reset(Tensor::create<int16_t>(dstWeight->shape()));
auto f32 = dstWeight->host<float>();
auto f16 = dstWeightHalf->host<__fp16>();
for (int i = 0; i < dstWeight->elementSize(); ++i) {
f16[i] = f32[i];
}
bytes = dstWeightHalf->host<void>();
bytenumber = 2;
} else {
bytes = dstWeight->host<float>();
bytenumber = 4;
}
return [context newDeviceBuffer:4 * UP_DIV(ic, 4) * UP_DIV(oc, 4) * mSrcUnit * mSrcUnit * 4 * bytenumber
bytes:bytes
access:CPUWriteOnly];
}
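For reference, the Winograd temporaries above are now sized with the runtime element width (2 bytes in fp16 mode, 4 in fp32 mode) instead of sizeof(metal_float). A minimal worked example of the same arithmetic; the dimension values are hypothetical and only serve to show the scale:

// Illustrative numbers only; the formula is the one used in onResize above.
int srcUnit = 4, us = 16, iz = 8, oz = 16;            // hypothetical unit / batch-tile / channel-block counts
int bytes = true /* useFp16InsteadFp32() */ ? 2 : 4;  // runtime precision switch
int is = srcUnit * srcUnit * us * iz * 16 * bytes;    // temp source: 4*4*16*8*16*2  = 65536 bytes
int os = srcUnit * srcUnit * us * oz * 16 * bytes;    // temp dest:   4*4*16*16*16*2 = 131072 bytes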

View File

@ -9,19 +9,17 @@
#ifndef MetalDeconvolution_hpp
#define MetalDeconvolution_hpp
#import "core/Execution.hpp"
#import "MetalExecution.hpp"
#import "MNN_generated.h"
#include "MNN_generated.h"
#import "MetalDefine.h"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalDeconvolution : public Execution {
class MetalDeconvolution : public MetalExecution {
public:
MetalDeconvolution(Backend *backend, const MNN::Op *op);
virtual ~MetalDeconvolution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
bool mDepthwise = false;

View File

@ -38,13 +38,9 @@ static id<MTLBuffer> weightForDeconv(MNNMetalContext *context, int group, int oc
auto dst = (TType *)buffer.contents;
for (int g = 0; g < group; g++) {
#pragma clang loop vectorize(enable)
for (int i = 0; i < gic; i++) {
#pragma clang loop vectorize(enable)
for (int o = 0; o < goc; o++) {
#pragma clang loop vectorize(enable)
for (int h = 0; h < kh; h++) {
#pragma clang loop vectorize(enable) unroll(enable)
for (int w = 0; w < kw; w++) {
auto zo = o / 4, ro = o % 4;
auto zi = i / 4, ri = i % 4;
@ -85,7 +81,8 @@ static id<MTLBuffer> weightForDepthwise(MNNMetalContext *context, int group, int
return buffer;
}
static id<MTLBuffer> weightForDeconv(MNNMetalContext *context, bool depthwise, const Convolution2D *deconv,
template <typename TType>
id<MTLBuffer> weightForDeconv(MNNMetalContext *context, bool depthwise, const Convolution2D *deconv,
ConvolutionCommon::Int8Common *qnt) {
auto size = qnt ? qnt->weightFloat.size() : deconv->weight()->size();
auto common = deconv->common();
@ -95,31 +92,43 @@ static id<MTLBuffer> weightForDeconv(MNNMetalContext *context, bool depthwise, c
auto oc = common->outputCount();
auto ic = size / kw / kh / (oc / group);
if (depthwise) {
return weightForDepthwise<float, metal_float>(context, group, kh, kw,
return weightForDepthwise<float, TType>(context, group, kh, kw,
qnt ? qnt->weightFloat.get() : deconv->weight()->data());
} else {
return weightForDeconv<float, metal_float>(context, group, oc, ic, kh, kw,
return weightForDeconv<float, TType>(context, group, oc, ic, kh, kw,
qnt ? qnt->weightFloat.get() : deconv->weight()->data());
}
}
static id<MTLBuffer> biasForDeconv(MNNMetalContext *context, const Convolution2D *deconv) {
static id<MTLBuffer> biasForDeconv(MNNMetalContext *context, const Convolution2D *deconv, bool fp16) {
auto bias = deconv->bias();
if (!bias || bias->size() == 0)
return [context newDeviceBuffer:0 access:CPUTransparent];
auto oc = deconv->common()->outputCount();
auto buffer = [context newDeviceBuffer:UP_DIV(oc, 4) * 4 * sizeof(metal_float) access:CPUWriteOnly];
int bytes = 4;
if (fp16) {
bytes = 2;
}
auto buffer = [context newDeviceBuffer:UP_DIV(oc, 4) * 4 * bytes access:CPUWriteOnly];
auto src = bias->data();
auto dst = (metal_float *)buffer.contents;
#pragma clang loop vectorize(enable) unroll(enable)
for (int i = 0; i < oc; i++)
dst[i] = src[i];
if (fp16) {
auto dst = (__fp16 *)buffer.contents;
for (int i = 0; i < oc; i++) {
dst[i] = src[i];
}
} else {
auto dst = (float *)buffer.contents;
for (int i = 0; i < oc; i++) {
dst[i] = src[i];
}
}
return buffer;
}
MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Execution(backend) {
MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : MetalExecution(backend) {
auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
auto mtbn = static_cast<MetalBackend *>(backend);
auto deconv = op->main_as_Convolution2D();
auto common = deconv->common();
mOp = op;
@ -141,12 +150,16 @@ MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Ex
if (deconv->quanParameter()) {
qnt = ConvolutionCommon::load(deconv, backend, true);
}
mWeight = weightForDeconv(context, mDepthwise, deconv, qnt.get());
mBias = biasForDeconv(context, deconv);
if (mDepthwise) {
mPipeline = [context pipelineWithName:@"deconv_depthwise"];
} else {
mPipeline = [context pipelineWithName:@"deconv"];
}
if (mtbn->useFp16InsteadFp32()) {
mWeight = weightForDeconv<__fp16>(context, mDepthwise, deconv, qnt.get());
} else {
mWeight = weightForDeconv<float>(context, mDepthwise, deconv, qnt.get());
}
mBias = biasForDeconv(context, deconv, mtbn->useFp16InsteadFp32());
if (mDepthwise) {
mPipeline = [context pipelineWithName:@"deconv_depthwise" fp16:mtbn->useFp16InsteadFp32()];
} else {
mPipeline = [context pipelineWithName:@"deconv" fp16:mtbn->useFp16InsteadFp32()];
}
}
@ -198,35 +211,15 @@ ErrorCode MetalDeconvolution::onResize(const std::vector<Tensor *> &inputs, cons
return NO_ERROR;
}
ErrorCode MetalDeconvolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto input = inputs[0], output = outputs[0];
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
void MetalDeconvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0], output = outputs[0];
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
[encoder setBuffer:mWeight offset:0 atIndex:3];
[encoder setBuffer:mBias offset:0 atIndex:4];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
class MetalDeconvolutionCreator : public MetalBackend::Creator { class MetalDeconvolutionCreator : public MetalBackend::Creator {

View File

@ -19,11 +19,6 @@
#import <float.h>
#endif
#if (TARGET_OS_IPHONE && TARGET_OS_SIMULATOR)
#undef MNN_METAL_ENABLED
#define MNN_METAL_ENABLED 0
#endif
#endif
#ifndef MNN_METAL_DEBUG
#if DEBUG
@ -34,14 +29,6 @@
#endif
#define MNN_METAL_BENCHMARK 0
#define MNN_METAL_FULL_PRECISION 0 // should edit in metal too
#if MNN_METAL_FULL_PRECISION || !defined(__FLT16_EPSILON__)
typedef float metal_float;
#define MNNMetalPixelFormatRGBAFloat MTLPixelFormatRGBA32Float
#else
typedef __fp16 metal_float;
#define MNNMetalPixelFormatRGBAFloat MTLPixelFormatRGBA16Float
#endif
#endif /* MetalDefine_h */
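With the MNN_METAL_FULL_PRECISION switch and the metal_float typedef removed, precision becomes a runtime property of the backend (useFp16InsteadFp32()), and each execution converts fp32 host data itself when needed. A minimal sketch of that pattern, assuming a Clang/ARM toolchain with __fp16 support; the helper name convertWeightsForDevice is illustrative and not part of MNN:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Mirrors the per-op logic in this commit: pick the element width from the
// runtime flag and narrow fp32 -> fp16 element by element when fp16 is on.
static std::vector<uint8_t> convertWeightsForDevice(const float* src, size_t count, bool useFp16) {
    const size_t elementBytes = useFp16 ? 2 : 4;
    std::vector<uint8_t> buffer(count * elementBytes);
    if (useFp16) {
        auto dst = reinterpret_cast<__fp16*>(buffer.data());
        for (size_t i = 0; i < count; ++i) {
            dst[i] = static_cast<__fp16>(src[i]);   // same narrowing as weightForFloat / biasForDeconv above
        }
    } else {
        ::memcpy(buffer.data(), src, count * sizeof(float));
    }
    return buffer;
}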

View File

@ -8,23 +8,21 @@
#ifndef MetalEltwise_hpp
#define MetalEltwise_hpp
#import "MetalExecution.hpp"
#import "core/Execution.hpp"
#import "MNN_generated.h"
#import "MetalDefine.h"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalEltwise : public Execution {
class MetalEltwise : public MetalExecution {
public:
MetalEltwise(Backend *backend, EltwiseType type);
virtual ~MetalEltwise() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:
void encode(const Tensor *input0, const Tensor *input1, const Tensor *output);
void encode(const Tensor *input0, const Tensor *input1, const Tensor *output, id<MTLComputeCommandEncoder> encoder);
id<MTLComputePipelineState> mPipeline;
id<MTLBuffer> mConst;
std::pair<MTLSize, MTLSize> mThreads;

View File

@ -14,7 +14,7 @@
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
MetalEltwise::MetalEltwise(Backend *backend, EltwiseType type) : Execution(backend) {
MetalEltwise::MetalEltwise(Backend *backend, EltwiseType type) : MetalExecution(backend) {
auto metal = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)metal->context();
mConst = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
@ -32,7 +32,7 @@ MetalEltwise::MetalEltwise(Backend *backend, EltwiseType type) : Execution(backe
default:
break;
}
mPipeline = [context pipelineWithName:kernel];
mPipeline = [context pipelineWithName:kernel fp16:metal->useFp16InsteadFp32()];
}
ErrorCode MetalEltwise::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
((int*)(mConst.contents))[0] = outputs[0]->elementSize();
@ -43,9 +43,7 @@ ErrorCode MetalEltwise::onResize(const std::vector<Tensor *> &inputs, const std:
return NO_ERROR;
}
void MetalEltwise::encode(const Tensor *input0, const Tensor *input1, const Tensor *output) {
void MetalEltwise::encode(const Tensor *input0, const Tensor *input1, const Tensor *output, id<MTLComputeCommandEncoder> encoder) {
auto metal = static_cast<MetalBackend *>(this->backend());
auto encoder = metal->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
@ -54,30 +52,12 @@ void MetalEltwise::encode(const Tensor *input0, const Tensor *input1, const Tens
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
ErrorCode MetalEltwise::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto output = outputs[0];
encode(inputs[0], inputs[1], output);
for (int i = 2; i < inputs.size(); i++) {
encode(inputs[i], output, output);
}
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
void MetalEltwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto output = outputs[0];
encode(inputs[0], inputs[1], output, encoder);
for (int i = 2; i < inputs.size(); i++) {
encode(inputs[i], output, output, encoder);
}
}
class MetalEltwiseCreator : public MetalBackend::Creator {

View File

@ -0,0 +1,21 @@
#ifndef MetalExecution_hpp
#define MetalExecution_hpp
#include "core/Execution.hpp"
#import "MetalDefine.h"
#include <string>
#if MNN_METAL_ENABLED
namespace MNN {
class MetalExecution : public Execution {
public:
MetalExecution(Backend *backend);
virtual ~MetalExecution() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) = 0;
};
} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif

View File

@ -0,0 +1,32 @@
#include "MetalExecution.hpp"
#import "backend/metal/MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
MetalExecution::MetalExecution(Backend *backend) : Execution(backend) {
// Do nothing
}
ErrorCode MetalExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto encoder = backend->encoder_for_net();
this->onEncode(inputs, outputs, encoder);
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
backend->commit_net();
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
};
#endif
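MetalExecution now owns the encoder bookkeeping that every op used to repeat (the isCommandEncoderSet() early-out, the re-encode lambda, addOpEncoder, and the conditional commit), so a subclass only describes how to encode itself. A minimal sketch of an op under that contract; the op name, shader name and buffer layout are hypothetical, and only the MetalExecution/MetalBackend calls come from this commit. mThreads would be picked in an onResize() that is omitted here:

class MetalReluSketch : public MetalExecution {
public:
    MetalReluSketch(Backend *backend) : MetalExecution(backend) {
        auto mtbn    = static_cast<MetalBackend *>(backend);
        auto context = (__bridge MNNMetalContext *)mtbn->context();
        // Pipelines are looked up with the runtime precision flag instead of a global metal_float typedef.
        mPipeline = [context pipelineWithName:@"relu" fp16:mtbn->useFp16InsteadFp32()];
    }
    virtual ~MetalReluSketch() = default;
    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
        // Bind buffers and dispatch only; no isCommandEncoderSet()/addOpEncoder() handling is needed here.
        [encoder setComputePipelineState:mPipeline];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[0])->extra.offset atIndex:0];
        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)outputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(outputs[0])->extra.offset atIndex:1];
        [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
    }
private:
    id<MTLComputePipelineState> mPipeline;
    std::pair<MTLSize, MTLSize> mThreads;
};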

View File

@ -9,18 +9,17 @@
#ifndef MetalFuse_hpp
#define MetalFuse_hpp
#import "core/Execution.hpp"
#import "MetalExecution.hpp"
#import "MNN_generated.h"
#import "MetalDefine.h"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalFuse : public Execution {
class MetalFuse : public MetalExecution {
public:
MetalFuse(Backend *backend, const Op* op);
virtual ~MetalFuse() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:

View File

@ -16,7 +16,7 @@
#if MNN_METAL_ENABLED #if MNN_METAL_ENABLED
namespace MNN { namespace MNN {
// #define MNN_FUSE_DEBUG // #define MNN_FUSE_DEBUG
MetalFuse::MetalFuse(Backend *backend, const Op* op) : Execution(backend), mOp(op) {
MetalFuse::MetalFuse(Backend *backend, const Op* op) : MetalExecution(backend), mOp(op) {
auto mtbn = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context();
mConstBuffer = [context newDeviceBuffer:3 * sizeof(int) access:CPUWriteOnly];
@ -27,9 +27,7 @@ MetalFuse::MetalFuse(Backend *backend, const Op* op) : Execution(backend), mOp(o
#ifdef MNN_FUSE_DEBUG
MNN_PRINT("MetalFuse srcCode:\n%s\n", srcCode);
#endif
auto source = [[NSString alloc] initWithUTF8String:ss.str().c_str()];
auto name = [[NSString alloc] initWithUTF8String:extra->type()->c_str()];
mPipeline = [context pipelineWithSource:source name:name];
mPipeline = mtbn->makeComputePipelineWithSourceOption(ss.str().c_str(), extra->type()->c_str(), nil);
}
ErrorCode MetalFuse::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
@ -43,61 +41,186 @@ ErrorCode MetalFuse::onResize(const std::vector<Tensor *> &inputs, const std::ve
return NO_ERROR;
}
ErrorCode MetalFuse::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
void MetalFuse::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto input = inputs[0], output = outputs[0];
[encoder setComputePipelineState:mPipeline];
int i = 0;
for (; i < inputs.size(); i++) {
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[i]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[i])->extra.offset atIndex:i];
}
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:i++];
[encoder setBuffer:mConstBuffer offset:0 atIndex:i++];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
#ifdef MNN_FUSE_DEBUG
auto dump = [&backend](const Tensor* t) {
auto outDimType = t->getDimensionType();
auto expectTensor = new MNN::Tensor(t, outDimType);
backend->onCopyBuffer(t, expectTensor);
MNN_PRINT("[ ");
for (int i = 0; i < 10; i++) {
MNN_PRINT("%f, ", expectTensor->host<float>()[i]);
}
MNN_PRINT(" ]\n");
delete expectTensor;
};
{
MNN_PRINT("=============================\n");
for (int i = 0; i < inputs.size(); i++) {
inputs[i]->wait(Tensor::MAP_TENSOR_READ, true);
dump(inputs[i]);
}
output->wait(Tensor::MAP_TENSOR_READ, true);
dump(output);
MNN_PRINT("=============================\n");
}
#endif
}
static bool _isStandardFuse(const Op* op) {
if (op->type() != OpType_Extra) {
return false;
}
if (nullptr == op->main_as_Extra()) {
return false;
}
auto extra = op->main_as_Extra();
if (nullptr == extra->attr()) {
return false;
}
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "version") {
if (nullptr != attr->s()) {
std::string cont = attr->s()->str();
return cont == "common";
}
return false;
}
}
return false;
}
class MetalFuseV2 : public MetalExecution {
public:
MetalFuseV2(Backend *backend, const Op* op, int outputSize, int inputSize) : MetalExecution(backend) {
mOutputBinding.resize(outputSize);
mInputBinding.resize(inputSize);
auto mtbn = static_cast<MetalBackend*>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context();
auto extra = op->main_as_Extra();
// Find shader
const char* source = nil;
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "metal") {
source = attr->s()->c_str();
break;
}
}
mPipeline = mtbn->makeComputePipelineWithSourceOption(source, "main0", nil);
// Init size
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "group_size") {
auto ptr = attr->tensor()->int32s()->data();
mGroupSize.width = ptr[0];
mGroupSize.height = ptr[1];
mGroupSize.depth = ptr[2];
break;
}
}
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "local_size") {
auto ptr = attr->tensor()->int32s()->data();
mThreadSize.width = ptr[0];
mThreadSize.height = ptr[1];
mThreadSize.depth = ptr[2];
break;
}
}
int maxIndex = -1;
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "input") {
maxIndex = ALIMAX(maxIndex, attr->i());
} else if (attr->key()->str() == "const") {
maxIndex = ALIMAX(maxIndex, attr->i());
}
}
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "input") {
auto list = attr->list()->i()->data();
if (list[1] >= 0) {
if (0 == list[0]) {
mInputBinding[list[1]] = attr->i();
} else {
mOutputBinding[list[1]] = attr->i();
}
}
continue;
}
if (attr->key()->str() == "const") {
auto b = attr->tensor();
void* result = nullptr;
size_t bufferSize = 0;
switch (b->dataType()) {
case DataType_DT_FLOAT:
result = (void*)b->float32s()->Data();
bufferSize = b->float32s()->size() * sizeof(float);
break;
case DataType_DT_INT32:
result = (void*)b->int32s()->Data();
bufferSize = b->int32s()->size() * sizeof(float);
break;
default:
MNN_ASSERT(false);
break;
}
// TODO: Fuse All Const Buffer to One buffer
id<MTLBuffer> constBuffer = [context newDeviceBuffer:bufferSize access:CPUWriteOnly];
::memcpy([constBuffer contents], result, bufferSize);
mConstIndides.emplace_back(std::make_pair(attr->i(), std::make_pair(constBuffer, 0)));
continue;
}
}
}
virtual ~MetalFuseV2() = default;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override {
[encoder setComputePipelineState:mPipeline];
for (int i=0; i<inputs.size(); ++i) {
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[i]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[i])->extra.offset atIndex:mInputBinding[i]];
}
for (int i=0; i<outputs.size(); ++i) {
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)outputs[i]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(outputs[i])->extra.offset atIndex:mOutputBinding[i]];
}
for (int i=0; i<mConstIndides.size(); ++i) {
[encoder setBuffer:mConstIndides[i].second.first offset:0 atIndex:mConstIndides[i].first];
}
[encoder dispatchThreadgroups:mGroupSize threadsPerThreadgroup:mThreadSize];
}
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
private:
MTLSize mGroupSize;
MTLSize mThreadSize;
std::vector<int> mInputBinding;
std::vector<int> mOutputBinding;
std::vector<std::pair<int, std::pair<id<MTLBuffer>, size_t>>> mConstIndides;
id<MTLComputePipelineState> mPipeline;
};
auto func = [=](){
auto input = inputs[0], output = outputs[0];
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
int i = 0;
for (; i < inputs.size(); i++) {
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[i]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[i])->extra.offset atIndex:i];
}
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:i++];
[encoder setBuffer:mConstBuffer offset:0 atIndex:i++];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
#ifdef MNN_FUSE_DEBUG
auto dump = [&backend](const Tensor* t) {
auto outDimType = t->getDimensionType();
auto expectTensor = new MNN::Tensor(t, outDimType);
backend->onCopyBuffer(t, expectTensor);
MNN_PRINT("[ ");
for (int i = 0; i < 10; i++) {
MNN_PRINT("%f, ", expectTensor->host<float>()[i]);
}
MNN_PRINT(" ]\n");
delete expectTensor;
};
{
MNN_PRINT("=============================\n");
for (int i = 0; i < inputs.size(); i++) {
inputs[i]->wait(Tensor::MAP_TENSOR_READ, true);
dump(inputs[i]);
}
output->wait(Tensor::MAP_TENSOR_READ, true);
dump(output);
MNN_PRINT("=============================\n");
}
#endif
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
}
class MetalFuseCreator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
if (_isStandardFuse(op)) {
return new MetalFuseV2(backend, op, (int)outputs.size(), (int)inputs.size());
}
return new MetalFuse(backend, op);
}
};
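MetalFuseV2 drives the whole launch from the Extra op's attributes: "metal" carries the shader source, "group_size" and "local_size" the dispatch shape, and "input"/"const" the buffer bindings. A small helper along these lines (illustrative only, not part of MNN) shows how one of the three-component size attributes is read back:

// Reads a 3-element int32 attribute such as "group_size" or "local_size" from an Extra op.
// Uses only schema accessors that appear in MetalFuseV2 above.
static bool readLaunchSize(const MNN::Extra* extra, const std::string& key, int size[3]) {
    if (nullptr == extra || nullptr == extra->attr()) {
        return false;
    }
    for (int i = 0; i < extra->attr()->size(); ++i) {
        auto attr = extra->attr()->GetAs<MNN::Attribute>(i);
        if (attr->key()->str() == key && nullptr != attr->tensor() && nullptr != attr->tensor()->int32s()) {
            auto ptr = attr->tensor()->int32s()->data();
            size[0] = ptr[0];
            size[1] = ptr[1];
            size[2] = ptr[2];
            return true;
        }
    }
    return false;
}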

View File

@ -9,18 +9,18 @@
#ifndef MetalGridSample_hpp
#define MetalGridSample_hpp
#import "core/Execution.hpp"
#import "MetalExecution.hpp"
#import "MNN_generated.h"
#import "MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalGridSample : public Execution {
class MetalGridSample : public MetalExecution {
public:
MetalGridSample(Backend *backend, const GridSample* gridSample);
virtual ~MetalGridSample() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:

View File

@ -13,7 +13,7 @@
namespace MNN {
MetalGridSample::MetalGridSample(Backend *backend, const GridSample *gridSample)
: Execution(backend) {
: MetalExecution(backend) {
mMode = gridSample->mode();
mPaddingMode = gridSample->paddingMode();
mAlignCorners = gridSample->alignCorners();
@ -40,7 +40,7 @@ ErrorCode MetalGridSample::onResize(const std::vector<Tensor *> &inputs,
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
mPipeline = [context pipelineWithName:@"grid_sample"];
mPipeline = [context pipelineWithName:@"grid_sample" fp16:backend->useFp16InsteadFp32()];
int batches = ((int *)mParams.contents)[0];
int channels = ((int *)mParams.contents)[1];
@ -52,32 +52,13 @@ ErrorCode MetalGridSample::onResize(const std::vector<Tensor *> &inputs,
return NO_ERROR;
}
void MetalGridSample::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[0])->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[1]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[1])->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)outputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(outputs[0])->extra.offset atIndex:2];
[encoder setBuffer:mParams offset:0 atIndex:3];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
ErrorCode MetalGridSample::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[0])->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[1]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[1])->extra.offset atIndex:1];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)outputs[0]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(outputs[0])->extra.offset atIndex:2];
[encoder setBuffer:mParams offset:0 atIndex:3];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
} }
class MetalGridSampleCreator : public MetalBackend::Creator {

View File

@ -9,17 +9,16 @@
#ifndef MetalInterp_hpp
#define MetalInterp_hpp
#include "core/Execution.hpp"
#include "MetalExecution.hpp"
#include "MetalDefine.h"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalInterp : public Execution {
class MetalInterp : public MetalExecution {
public:
MetalInterp(Backend *backend, const Op* op);
virtual ~MetalInterp() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
private:

View File

@ -15,7 +15,7 @@
namespace MNN {
MetalInterp::MetalInterp(Backend *backend, const Op* op)
: Execution(backend) {
: MetalExecution(backend) {
auto interpParam = op->main_as_Interp();
auto mBk = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)mBk->context();
@ -43,12 +43,12 @@ ErrorCode MetalInterp::onResize(const std::vector<Tensor *> &inputs, const std::
((int *)mShape.contents)[6] = slice;
if (mReiszeType == 2 || mReiszeType == 1) {
if (2 == mReiszeType) {
mPipeline = [context pipelineWithName:@"resize_bilinear"];
mPipeline = [context pipelineWithName:@"resize_bilinear" fp16:backend->useFp16InsteadFp32()];
} else {
mPipeline = [context pipelineWithName:@"resize_nearest"];
mPipeline = [context pipelineWithName:@"resize_nearest" fp16:backend->useFp16InsteadFp32()];
}
} else if (mReiszeType == 3) {
mPipeline = [context pipelineWithName:@"resize_cubic"];
mPipeline = [context pipelineWithName:@"resize_cubic" fp16:backend->useFp16InsteadFp32()];
} else {
MNN_ASSERT(false);
}
@ -57,36 +57,15 @@ ErrorCode MetalInterp::onResize(const std::vector<Tensor *> &inputs, const std::
return NO_ERROR;
}
void MetalInterp::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0], output = outputs[0];
// encode
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mShape offset:0 atIndex:2];
[encoder setBuffer:mCordTransform offset:0 atIndex:3];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
ErrorCode MetalInterp::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
if(backend->isCommandEncoderSet()) {
return NO_ERROR;
}
auto func = [=](){
auto input = inputs[0], output = outputs[0];
// encode
auto encoder = backend->encoder();
[encoder setComputePipelineState:mPipeline];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
[encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
[encoder setBuffer:mShape offset:0 atIndex:2];
[encoder setBuffer:mCordTransform offset:0 atIndex:3];
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
auto context = (__bridge MNNMetalContext *)backend->context();
if(backend->isCmdBufferCommit()) {
backend->flushEncoder();
[context commit_net];
}
};
func();
backend->addOpEncoder(func);
return NO_ERROR;
} }
class MetalInterpCreator : public MetalBackend::Creator {

View File

@ -9,19 +9,18 @@
#ifndef MetalLayerNorm_hpp
#define MetalLayerNorm_hpp
#import "core/Execution.hpp"
#import "MetalExecution.hpp"
#import "MNN_generated.h"
#import "MetalDefine.h"
#if MNN_METAL_ENABLED
namespace MNN {
class MetalLayerNorm : public Execution {
class MetalLayerNorm : public MetalExecution {
public:
MetalLayerNorm(Backend *backend, const LayerNorm *layernorm);
virtual ~MetalLayerNorm() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
private:
int mOutside;

View File

@@ -14,7 +14,7 @@
 namespace MNN {
 MetalLayerNorm::MetalLayerNorm(Backend *backend, const LayerNorm *layernorm)
-    : Execution(backend), mGroup(layernorm->group()),
+    : MetalExecution(backend), mGroup(layernorm->group()),
       mEps(layernorm->epsilon()) {
     auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
@@ -69,10 +69,10 @@ ErrorCode MetalLayerNorm::onResize(const std::vector<Tensor *> &inputs, const st
     }
     std::sort(mAxis.begin(), mAxis.end());
-    for (int i = 0; i < rank - axis.size(); ++i) {
+    for (int i = 0; i < rank - (int)axis.size(); ++i) {
         mOutside *= input->length(i);
     }
-    for (int i = rank - axis.size(); i < rank; ++i) {
+    for (int i = rank - (int)axis.size(); i < rank; ++i) {
         mInside *= input->length(i);
     }
@@ -84,44 +84,26 @@ ErrorCode MetalLayerNorm::onResize(const std::vector<Tensor *> &inputs, const st
     bool parallel = (mInside > 32) && ((mInside & 3) == 0);
-    mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4" : @"layernorm_x1"];
+    mPipeline = [context pipelineWithName:parallel ? @"layernorm_x4" : @"layernorm_x1" fp16:backend->useFp16InsteadFp32()];
     auto inside = parallel ? mInside/4 : mInside;
     mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger)inside, (NSUInteger)mOutside, 1)];
     return NO_ERROR;
 }

-ErrorCode MetalLayerNorm::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+void MetalLayerNorm::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
     auto backend = static_cast<MetalBackend *>(this->backend());
     auto context = (__bridge MNNMetalContext *)backend->context();
-    if(backend->isCommandEncoderSet()) {
-        return NO_ERROR;
-    }
-    auto func = [=](){
-        auto input = inputs[0], output = outputs[0];
-        auto encoder = backend->encoder();
-        [encoder setComputePipelineState:mPipeline];
-        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
-        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
-        [encoder setBuffer:mShapeBuffer offset:0 atIndex:2];
-        [encoder setBuffer:mGammaBuffer offset:0 atIndex:3];
-        [encoder setBuffer:mBetaBuffer offset:0 atIndex:4];
-        [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
-        MNN_PRINT_ENCODER(context, encoder);
-        auto context = (__bridge MNNMetalContext *)backend->context();
-        if(backend->isCmdBufferCommit()) {
-            backend->flushEncoder();
-            [context commit_net];
-        }
-    };
-    func();
-    backend->addOpEncoder(func);
-    return NO_ERROR;
+    auto input = inputs[0], output = outputs[0];
+    [encoder setComputePipelineState:mPipeline];
+    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
+    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
+    [encoder setBuffer:mShapeBuffer offset:0 atIndex:2];
+    [encoder setBuffer:mGammaBuffer offset:0 atIndex:3];
+    [encoder setBuffer:mBetaBuffer offset:0 atIndex:4];
+    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
+    MNN_PRINT_ENCODER(context, encoder);
 }

 class MetalLayerNormCreator : public MetalBackend::Creator {
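The onResize hunk above is mostly a signedness fix: rank - axis.size() mixes int with size_t, and the added (int) casts keep the loop bounds signed. The split itself multiplies the leading lengths into mOutside and the trailing (normalized) lengths into mInside, and mInside then selects between the layernorm_x4 and layernorm_x1 kernels. A small self-contained sketch with made-up shapes (not taken from the commit):

    // Standalone C++ illustration of the outside/inside split and kernel choice.
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> length = {2, 3, 8, 16}; // hypothetical input lengths
        int rank = (int)length.size();
        size_t axisCount = 2;                    // number of normalized trailing axes

        int outside = 1, inside = 1;
        for (int i = 0; i < rank - (int)axisCount; ++i) outside *= length[i];   // 2 * 3  = 6
        for (int i = rank - (int)axisCount; i < rank; ++i) inside *= length[i]; // 8 * 16 = 128

        // Same predicate as the diff: vectorized kernel when inside > 32 and a multiple of 4.
        bool parallel = (inside > 32) && ((inside & 3) == 0);
        printf("outside=%d inside=%d kernel=%s x-threads=%d\n", outside, inside,
               parallel ? "layernorm_x4" : "layernorm_x1", parallel ? inside / 4 : inside);
        return 0;
    }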

View File

@@ -9,18 +9,18 @@
 #ifndef MetalMatMul_hpp
 #define MetalMatMul_hpp
-#import "core/Execution.hpp"
+#import "MetalExecution.hpp"
 #import "MNN_generated.h"
 #import "MetalBackend.hpp"
 #if MNN_METAL_ENABLED
 namespace MNN {
-class MetalMatMul : public Execution {
+class MetalMatMul : public MetalExecution {
 public:
     MetalMatMul(Backend *backend, const MatMul *matmul);
     virtual ~MetalMatMul() = default;
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 private:

View File

@@ -18,7 +18,7 @@ struct matP {
     int size[4];
     int stride[4];
 };
-MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul) : Execution(backend) {
+MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul) : MetalExecution(backend) {
     mTransposeA = matmul->transposeA();
     mTransposeB = matmul->transposeB();
     auto mkbn = static_cast<MetalBackend *>(backend);
@@ -57,51 +57,34 @@ ErrorCode MetalMatMul::onResize(const std::vector<Tensor *> &inputs, const std::
     return NO_ERROR;
 }

-ErrorCode MetalMatMul::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+void MetalMatMul::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
     auto backend = static_cast<MetalBackend *>(this->backend());
     auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
-    if(backend->isCommandEncoderSet()) {
-        return NO_ERROR;
-    }
-    auto func = [=](){
-        auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
-        Tensor* C = outputs[0];
-        auto e = C->length(0);
-        auto h = C->length(1);
-        auto encoder = backend->encoder();
-        if (inputs.size() > 2) {
-            auto bandwidth = [context load:@"matmul_bias" encoder:encoder];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[2]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[2])->extra.offset atIndex:2];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:3];
-            [encoder setBuffer:mConstBuffer offset:0 atIndex:4];
-            [context dispatchEncoder:encoder
-                             threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
-                           bandwidth:bandwidth];
-        } else {
-            auto bandwidth = [context load:@"matmul" encoder:encoder];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
-            [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
-            [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
-            [context dispatchEncoder:encoder
-                             threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
-                           bandwidth:bandwidth];
-        }
-        if(backend->isCmdBufferCommit()) {
-            backend->flushEncoder();
-            [context commit_net];
-        }
-    };
-    func();
-    backend->addOpEncoder(func);
-    return NO_ERROR;
+    auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
+    Tensor* C = outputs[0];
+    auto e = C->length(0);
+    auto h = C->length(1);
+    if (inputs.size() > 2) {
+        auto bandwidth = [context load:@"matmul_bias" encoder:encoder fp16:backend->useFp16InsteadFp32()];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)inputs[2]->deviceId())->getBuffer() offset:TensorUtils::getDescribe(inputs[2])->extra.offset atIndex:2];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:3];
+        [encoder setBuffer:mConstBuffer offset:0 atIndex:4];
+        [context dispatchEncoder:encoder
+                         threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
+                       bandwidth:bandwidth];
+    } else {
+        auto bandwidth = [context load:@"matmul" encoder:encoder fp16:backend->useFp16InsteadFp32()];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1];
+        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2];
+        [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
+        [context dispatchEncoder:encoder
+                         threads:{ (NSUInteger)h, (NSUInteger)e, (NSUInteger)1 }
+                       bandwidth:bandwidth];
+    }
 }

 class MetalMatMulCreator : public MetalBackend::Creator {
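The rewritten MetalMatMul::onEncode keeps the old dispatch geometry, a { h, e, 1 } grid with e = C->length(0) and h = C->length(1); only the argument table differs between the bias and no-bias kernels. Summarized from the bindings above (describing mConstBuffer as the matP-style size/stride constants is an inference from the struct shown earlier in this file, not stated in the hunk):

    // matmul_bias (inputs.size() > 2)          // matmul (two inputs)
    //   buffer 0: input0   (A)                 //   buffer 0: input0 (A)
    //   buffer 1: input1   (B)                 //   buffer 1: input1 (B)
    //   buffer 2: inputs[2] (bias)             //   buffer 2: output (C)
    //   buffer 3: output   (C)                 //   buffer 3: mConstBuffer
    //   buffer 4: mConstBuffer (presumably the matP size/stride constants)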

View File

@@ -9,18 +9,16 @@
 #ifndef MetalPReLU_hpp
 #define MetalPReLU_hpp
-#import "core/Execution.hpp"
-#import "MetalDefine.h"
+#import "MetalExecution.hpp"
 #if MNN_METAL_ENABLED
 namespace MNN {
-class MetalPReLU : public Execution {
+class MetalPReLU : public MetalExecution {
 public:
     MetalPReLU(Backend *backend, const float *slope, int count);
     virtual ~MetalPReLU() = default;
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
 private:
     id<MTLBuffer> mSlope;

View File

@@ -14,14 +14,15 @@
 #if MNN_METAL_ENABLED
 namespace MNN {
-MetalPReLU::MetalPReLU(Backend *backend, const float *slope, int count) : Execution(backend) {
+MetalPReLU::MetalPReLU(Backend *backend, const float *slope, int count) : MetalExecution(backend) {
     auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
     mSlope = [context newDeviceBuffer:UP_DIV(count, 4) * 4 * sizeof(float) bytes:slope access:CPUWriteOnly];
     mShareChannel = 1 == count;
     if (!mShareChannel) {
         mShape = [context newDeviceBuffer:3 * sizeof(int) access:CPUWriteOnly];
     }
-    mPipeline = [context pipelineWithName:mShareChannel ? @"prelu" : @"prelu_slopes"];
+    auto mtbn = static_cast<MetalBackend *>(backend);
+    mPipeline = [context pipelineWithName:mShareChannel ? @"prelu" : @"prelu_slopes" fp16:mtbn->useFp16InsteadFp32()];
 }

 ErrorCode MetalPReLU::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
@@ -40,35 +41,16 @@ ErrorCode MetalPReLU::onResize(const std::vector<Tensor *> &inputs, const std::v
     return NO_ERROR;
 }

-ErrorCode MetalPReLU::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
-    auto backend = static_cast<MetalBackend *>(this->backend());
-    if(backend->isCommandEncoderSet()) {
-        return NO_ERROR;
-    }
-    auto func = [=](){
-        auto input = inputs[0], output = outputs[0];
-        auto encoder = backend->encoder();
-        [encoder setComputePipelineState:mPipeline];
-        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
-        [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
-        [encoder setBuffer:mSlope offset:0 atIndex:2];
-        if (!mShareChannel) {
-            [encoder setBuffer:mShape offset:0 atIndex:3];
-        }
-        [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
-        auto context = (__bridge MNNMetalContext *)backend->context();
-        if(backend->isCmdBufferCommit()) {
-            backend->flushEncoder();
-            [context commit_net];
-        }
-    };
-    func();
-    backend->addOpEncoder(func);
-    return NO_ERROR;
+void MetalPReLU::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
+    auto input = inputs[0], output = outputs[0];
+    [encoder setComputePipelineState:mPipeline];
+    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0];
+    [encoder setBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1];
+    [encoder setBuffer:mSlope offset:0 atIndex:2];
+    if (!mShareChannel) {
+        [encoder setBuffer:mShape offset:0 atIndex:3];
+    }
+    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
 }

 class MetalPReLUCreator : public MetalBackend::Creator {
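Two small points in the PReLU constructor hunk are easy to miss: the slope buffer is padded up to a multiple of four floats via UP_DIV(count, 4) * 4 (presumably so the 4-wide kernels never read past the data), and the fp16 flag that this commit threads through every pipelineWithName:/load: call comes from the backend (mtbn->useFp16InsteadFp32()), which suggests precision is now a per-backend runtime choice rather than a compile-time one. A quick sketch of the padding arithmetic with an illustrative channel count:

    // UP_DIV-style rounding as used for mSlope; the channel count below is made up.
    #include <cstdio>

    static int upDiv(int a, int b) { return (a + b - 1) / b; }

    int main() {
        int count = 37;                        // count == 1 would select the shared "prelu" kernel
        bool shareChannel = (count == 1);
        unsigned long bytes = (unsigned long)upDiv(count, 4) * 4 * sizeof(float); // 40 floats -> 160 bytes
        printf("kernel=%s slopeBuffer=%lu bytes\n",
               shareChannel ? "prelu" : "prelu_slopes", bytes);
        return 0;
    }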

View File

@@ -9,19 +9,18 @@
 #ifndef MetalPooling_hpp
 #define MetalPooling_hpp
-#import "core/Execution.hpp"
+#import "MetalExecution.hpp"
 #import "MNN_generated.h"
-#import "MetalDefine.h"
 #if MNN_METAL_ENABLED
 namespace MNN {
-class MetalPooling : public Execution {
+class MetalPooling : public MetalExecution {
 public:
     MetalPooling(Backend *backend, const Pool *pooling);
     virtual ~MetalPooling() = default;
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
 private:
     bool mGlobal;

Some files were not shown because too many files have changed in this diff.