Merge pull request #3603 from alibaba/feature/sync
android / android_build (push) Waiting to run Details
ios / ios_build (push) Has been cancelled Details
linux / linux_buil_test (push) Has been cancelled Details
macos / macos_buil_test (push) Has been cancelled Details
wiki / sync_wiki (push) Has been cancelled Details
windows / windows_build_test (push) Has been cancelled Details

MNN:Sync: Sync 3.2.0
This commit is contained in:
jxt1234 2025-06-06 09:35:38 +08:00 committed by GitHub
commit ebdada8263
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
110 changed files with 10685 additions and 2771 deletions

View File

@ -126,7 +126,7 @@ mkdir build && cd build && cmake .. -DCMAKE_OSX_ARCHITECTURES=arm64 && make -j8
- 基于脚本编译:运行脚本并开启`MNN_ARM82`选项
```
sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true"
sh package_scripts/ios/buildiOS.sh -DMNN_ARM82=ON
```
## 鸿蒙(Harmony)

View File

@ -20,9 +20,11 @@
### 图像实例分割
代码位置:`demo/exec/segment.cpp`
下载 deeplabv3 分割模型并转换到 mnn 模型
下载 deeplabv3 分割模型
[https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite](https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite)
使用 [模型转换工具](../tools/convert.md) 转换为 MNN 模型,转换时加上参数 --keepInputFormat=0 【把输入由NHWC转换为NC4HW4布局】
```bash
./segment.out model.mnn input.png result.png
```
@ -95,14 +97,14 @@ flops_info: 568.792175M
backend_info: 13
expect 983
output belong to class: 983
$ python gpu_session_demo.py mobilenet_demo/mobilenet_v1.mnn mobilenet_demo/ILSVRC2012_val_00049999.JPEG
$ python gpu_session_demo.py mobilenet_demo/mobilenet_v1.mnn mobilenet_demo/ILSVRC2012_val_00049999.JPEG
Testing gpu model calling method
Load Cache file error.
MNN use high precision
Can't Find type=3 backend, use 0 instead
Can't Find type=3 backend, use 0 instead
Run on backendtype: 13
Run on backendtype: 13
expect 983
output belong to class: 983
@ -127,7 +129,7 @@ output belong to class: 983
#### mnist
使用mnist数据训练模型并测试准确率,无需下载资源,用法如下:
```bash
$ pip install mnist
$ pip install mnist
$ python train_mnist.py
train loss: 2.3346531
train loss: 0.28027835
@ -161,7 +163,7 @@ AttributeError: module 'MNN.nn' has no attribute 'FixModule'
#### module_save
演示了模型权值的存储和加载
```bash
$ python test_save.py
$ python test_save.py
0.0004
10
```
@ -225,4 +227,3 @@ sh ../tools/script/get_model.sh
- [视频抠图](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit)
- [SuperGlue关键点匹配](https://github.com/Hanson0910/MNNSuperGlue)
- [OCR](https://github.com/DayBreak-u/chineseocr_lite/tree/onnx/android_projects/OcrLiteAndroidMNN)
- [Bert-VITS2-MNN](https://github.com/Voine/Bert-VITS2-MNN)

View File

@ -73,14 +73,17 @@ python llmexport.py \
- 使用`--lm_quant_bit`来指定lm_head层权重的量化bit数,不指定则使用`--quant_bit`的量化bit数
### 参数
执行 `python llmexport.py -h` 可查看参数:
```
usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT]
[--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT]
[--mnnconvert MNNCONVERT]
usage: llmexport.py [-h] --path PATH [--type TYPE] [--tokenizer_path TOKENIZER_PATH] [--lora_path LORA_PATH]
[--gptq_path GPTQ_PATH] [--dst_path DST_PATH] [--verbose] [--test TEST] [--export EXPORT]
[--onnx_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK]
[--lm_quant_bit LM_QUANT_BIT] [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--sym] [--seperate_embed]
[--lora_split]
llm_exporter
options:
optional arguments:
-h, --help show this help message and exit
--path PATH path(`str` or `os.PathLike`):
Can be either:
@ -88,19 +91,30 @@ options:
- A path to a *directory* clone from repo like `../chatglm-6b`.
--type TYPE type(`str`, *optional*):
The pretrain llm model type.
--tokenizer_path TOKENIZER_PATH
tokenizer path, defaut is `None` mean using `--path` value.
--lora_path LORA_PATH
lora path, defaut is `None` mean not apply lora.
--gptq_path GPTQ_PATH
gptq path, defaut is `None` mean not apply gptq.
--dst_path DST_PATH export onnx/mnn model to path, defaut is `./model`.
--verbose Whether or not to print verbose.
--test TEST test model inference with query `TEST`.
--export EXPORT export model to an onnx/mnn model.
--onnx_slim Whether or not to use onnx-slim.
--quant_bit QUANT_BIT
mnn quant bit, 4 or 8, default is 4.
--quant_block QUANT_BLOCK
mnn quant block, default is 0 mean channle-wise.
mnn quant block, 0 mean channle-wise, default is 128.
--lm_quant_bit LM_QUANT_BIT
mnn lm_head quant bit, 4 or 8, default is `quant_bit`.
--mnnconvert MNNCONVERT
local mnnconvert path, if invalid, using pymnn.
--ppl Whether or not to get all logits of input tokens.
--awq Whether or not to use awq quant.
--sym Whether or not to using symmetric quant (without zeropoint), defualt is False.
--seperate_embed For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.
--lora_split Whether or not export lora split, defualt is False.
```
### 权重读取

View File

@ -75,7 +75,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR_IMP(x) #x
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 3
#define MNN_VERSION_MINOR 1
#define MNN_VERSION_PATCH 4
#define MNN_VERSION_MINOR 2
#define MNN_VERSION_PATCH 0
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */

View File

@ -99,14 +99,22 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_
template <typename T>
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) {
if (seq_len == 1 || mask == nullptr) {
for (int i = 0; i < seq_len * kv_seq_len; i++) {
for (int i = 0; i < kv_seq_len; i++) {
unpack_qk[i] = unpack_qk[i] * mScale;
}
} else if (mask->getType() == halide_type_of<float>()) {
// float mask
T* fpmask_ptr = mask->host<T>();
for (int i = 0; i < seq_len * kv_seq_len; i++) {
unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i];
int offset = kv_seq_len-seq_len;
for (int i=0; i<seq_len; ++i) {
auto unpack_qki = unpack_qk + i * kv_seq_len;
auto fpmask_ptri = fpmask_ptr + i * seq_len;
for (int j=0; j<offset; ++j) {
unpack_qki[j] = unpack_qki[j] * mScale;
}
for (int j=0; j<seq_len; ++j) {
unpack_qki[offset+j] = unpack_qki[offset+j] * mScale + fpmask_ptri[j];
}
}
} else {
// int mask
@ -192,7 +200,6 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
int seq_len = query->length(1);
if (inputs.size() > 3) {
mask = inputs[3];
MNN_ASSERT(seq_len == mask->length(2));
}
int tileCount = UP_DIV(mNumHead, mThreadNum);
int group_size = mNumHead / mKvNumHead;

View File

@ -594,7 +594,7 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &____inputs, const st
}
auto core = static_cast<CPUBackend*>(backend())->functions();
auto output = outputs[0];
auto bytes = CPUBackend::getBytes(backend(), output);
size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output));
auto outputEleSize = static_cast<CPUBackend*>(backend())->getTensorSize(output);
auto threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
if (mSingleConvert.type > 0) {

View File

@ -72,7 +72,7 @@ struct d0
T data[1];
};
kernel void main0(device d0& uOutput [[buffer(0)]], const device s0& uInputA [[buffer(1)]], const device s1& uInputB [[buffer(2)]],
kernel void loop_matmul(device d0& uOutput [[buffer(0)]], const device s0& uInputA [[buffer(1)]], const device s1& uInputB [[buffer(2)]],
#ifdef HAS_BIAS
const device s2& uInputC [[buffer(3)]],
const device s3& uOOffset [[buffer(4)]],
@ -198,7 +198,7 @@ public:
@"HAS_BIAS":@"1",
};
}
pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "main0", compileOptions);
pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "loop_matmul", compileOptions);
mtbn->runtime()->insertPipeline(keys, pipeline);
}
if (nil == pipeline) {
@ -268,6 +268,7 @@ struct constBuffer
int4 extent;
int4 _step;
int4 iter;
int4 totalSize;
};
struct s1
@ -290,7 +291,7 @@ struct s0
T data[1];
};
kernel void main0(device sourceBuffer& uOutput [[buffer(0)]], const device s0& uInput [[buffer(1)]], const device s1& uSrcOffset [[buffer(2)]], const device s2& uDstOffset [[buffer(3)]], constant constBuffer& uConstant [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
kernel void gather_blit(device sourceBuffer& uOutput [[buffer(0)]], const device s0& uInput [[buffer(1)]], const device s1& uSrcOffset [[buffer(2)]], const device s2& uDstOffset [[buffer(3)]], constant constBuffer& uConstant [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
int3 posTmp = int3(gl_GlobalInvocationID);
if (posTmp.x < uConstant._step.w)
@ -322,7 +323,11 @@ kernel void main0(device sourceBuffer& uOutput [[buffer(0)]], const device s0& u
}
int srcOffset = (((srcBasicOffset + uConstant.stride.w) + (uConstant.stride.z * pos.z)) + (uConstant.stride.y * pos.y)) + (uConstant.stride.x * pos.x);
int dstOffset = (((dstBasicOffset + uConstant.extent.w) + (pos.x * uConstant.extent.x)) + (pos.y * uConstant.extent.y)) + (pos.z * uConstant.extent.z);
uOutput.data[dstOffset] = uInput.data[srcOffset];
if(srcOffset >= 0 && srcOffset < uConstant.totalSize.x) {
if(dstOffset >= 0 && dstOffset < uConstant.totalSize.y) {
uOutput.data[dstOffset] = uInput.data[srcOffset];
}
}
}
}
)metal";
@ -333,49 +338,139 @@ struct GatherInfo {
int extent[4];
int step[4];
int iter[4];
int totalSize[4];
};
// Host-side mirror of the `constBuffer` uniform consumed by the init kernels
// (set_zero / set_copy): the struct is written through the mapped contents of
// mInitParam, so field order and int4 (16-byte) packing must match the shader
// struct — TODO(review): confirm layout stays in sync with gInitRegion.
struct InitInfo {
    int srcStride[4];   // source strides for x/y/z; lane 3 is padding (only [0..2] written)
    int dstStride[4];   // destination strides for x/y/z; lane 3 is padding
    int size[4];        // dispatch extents per axis; lane 3 is padding
    int totalSize[4];   // [0] = input element count, [1] = output element count (bounds guards)
};
static const char* gInitRegion = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct constBuffer
{
int4 srcStride;
int4 dstStride;
int4 size;
int4 totalSize;
};
kernel void set_zero(device T *out [[buffer(0)]],
const device T *in [[buffer(1)]],
constant constBuffer &info [[buffer(2)]],
uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) {
int3 gid = int3(gl_GlobalInvocationID);
if (gid.x >= info.size.x || gid.y >= info.size.y || gid.z >= info.size.z) {
return;
}
int dst_offset = (gid.z * info.size.y + gid.y) * info.size.x + gid.x;
if(dst_offset >= 0 && dst_offset < info.totalSize.y) {
out[dst_offset] = (T)0;
}
}
kernel void set_copy(device T *out [[buffer(0)]],
const device T *in [[buffer(1)]],
constant constBuffer &info [[buffer(2)]],
uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) {
int3 gid = int3(gl_GlobalInvocationID);
if (gid.x >= info.size.x || gid.y >= info.size.y || gid.z >= info.size.z) {
return;
}
int src_offset = gid.x * info.srcStride.x + gid.y * info.srcStride.y + gid.z * info.srcStride.z;
int dst_offset = gid.x * info.dstStride.x + gid.y * info.dstStride.y + gid.z * info.dstStride.z;
if(src_offset >= 0 && src_offset < info.totalSize.x) {
if(dst_offset >= 0 && dst_offset < info.totalSize.y) {
out[dst_offset] = in[src_offset];
}
}
}
)metal";
class MetalGather : public MetalExecution {
private:
const LoopParam* mLoop;
bool mNeedInit = false;
std::pair<MTLSize, MTLSize> mInitThreads;
id<MTLBuffer> mParam;
id<MTLComputePipelineState> mPipeline;
id<MTLComputePipelineState> mInitPipeline;
id<MTLBuffer> mInitParam;
std::vector<Tensor*> mTensors;
public:
MetalGather(const LoopParam* loop, Backend *bn, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) : MetalExecution(bn) {
mLoop = loop;
auto mtbn = static_cast<MetalBackend *>(bn);
auto context = (__bridge MNNMetalContext *)mtbn->context();
mParam = [context newDeviceBuffer:sizeof(GatherInfo) access:CPUWriteOnly];
bool useFp16 = mtbn->useFp16InsteadFp32();
mTensors.resize(mLoop->tensorNumber());
auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
_setTensorStack(mTensors, inputs, outputs, mLoop);
auto dstTensor = mTensors[cmd->indexes()->data()[0]];
NSString* T = MetalCast::getScalarType(dstTensor->getType(), useFp16);
std::vector<std::string> keys = {
std::string([T UTF8String]),
"blitregion"
};
auto pipeline = mtbn->runtime()->findPipeline(keys);
if (nil == pipeline) {
MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
compileOptions.preprocessorMacros = @{
@"T" : T,
// gather blit command pipeline
{
std::vector<std::string> keys = {
std::string([T UTF8String]),
"blitregion"
};
pipeline = mtbn->makeComputePipelineWithSourceOption(gBlitRegion, "main0", compileOptions);
mtbn->runtime()->insertPipeline(keys, pipeline);
auto pipeline = mtbn->runtime()->findPipeline(keys);
if (nil == pipeline) {
MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
compileOptions.preprocessorMacros = @{
@"T" : T,
};
pipeline = mtbn->makeComputePipelineWithSourceOption(gBlitRegion, "gather_blit", compileOptions);
mtbn->runtime()->insertPipeline(keys, pipeline);
}
if (nil == pipeline) {
MNN_ERROR("Create gather pipeline error\n");
}
mPipeline = pipeline;
}
if (nil == pipeline) {
MNN_ERROR("Create gather pipeline error\n");
// scatter need init command pipeline
if(mLoop->initCommand() != nullptr){
mNeedInit = true;
std::string shader = "set_copy";
auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
if (cmd->op() == nullptr){
shader = "set_zero";
} else {
mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly];
}
std::vector<std::string> keys = {
std::string([T UTF8String]),
"init_region",
shader
};
auto pipeline = mtbn->runtime()->findPipeline(keys);
if (nil == pipeline) {
MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
compileOptions.preprocessorMacros = @{
@"T" : T,
};
pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions);
mtbn->runtime()->insertPipeline(keys, pipeline);
}
if (nil == pipeline) {
MNN_ERROR("Create gather init pipeline error\n");
}
mInitPipeline = pipeline;
}
mPipeline = pipeline;
}
virtual ~MetalGather() = default;
virtual ErrorCode onResize(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs) override {
auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
_setTensorStack(mTensors, inputs, outputs, mLoop);
auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
auto size = cmd->size()->data();
@ -401,22 +496,70 @@ public:
if (iterIndex[1] >= 0) {
param->iter[1] = 1;
}
auto dstTensor = mTensors[cmd->indexes()->data()[0]];
auto srcTensor = mTensors[cmd->indexes()->data()[1]];
auto inputSize = srcTensor->usize() / srcTensor->buffer().type.bytes();
auto outputSize = dstTensor->usize() / dstTensor->buffer().type.bytes();
param->totalSize[0] = inputSize;
param->totalSize[1] = outputSize;
if(mNeedInit) {
auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
auto data = reinterpret_cast<InitInfo*>([mInitParam contents]);
auto srcStride = initCmd->view()->GetAs<View>(1)->stride()->data();
auto dstStride = initCmd->view()->GetAs<View>(0)->stride()->data();
auto dataSize = initCmd->size()->data();
for (int i = 0; i < 3; ++i) {
data->srcStride[i] = srcStride[i];
data->dstStride[i] = dstStride[i];
data->size[i] = dataSize[i];
}
auto initDstTensor = mTensors[initCmd->indexes()->data()[0]];
auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]];
auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes();
auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes();
data->totalSize[0] = initInputSize;
data->totalSize[1] = initOutputSize;
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])];
}
return NO_ERROR;
}
virtual void onEncode(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs,
id<MTLComputeCommandEncoder> encoder) override {
if(mNeedInit) {
auto initCmd = mLoop->initCommand()->GetAs<RegionCommand>(0);
int x = initCmd->size()->data()[0];
int y = initCmd->size()->data()[1];
int z = initCmd->size()->data()[2];
[encoder setComputePipelineState:mInitPipeline];
auto dstTensor = mTensors[initCmd->indexes()->data()[0]];
auto srcTensor = mTensors[initCmd->indexes()->data()[1]];
MetalBackend::setTensor(dstTensor, encoder, 0);
MetalBackend::setTensor(srcTensor, encoder, 1);
[encoder setBuffer:mInitParam offset:0 atIndex:2];
[encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second];
}
auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
auto size = cmd->size()->data();
auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
int totalSize = mLoop->loopNumber() * size[0] * size[1] * size[2];
[encoder setComputePipelineState:mPipeline];
auto dstTensor = mTensors[cmd->indexes()->data()[0]];
auto srcTensor = mTensors[cmd->indexes()->data()[1]];
MetalBackend::setTensor(dstTensor, encoder, 0);
MetalBackend::setTensor(srcTensor, encoder, 1);
auto iterIndex = cmd->iterIndexes()->data();
if (iterIndex[0] >= 0) {
MetalBackend::setTensor(mTensors[iterIndex[0]], encoder, 3);
@ -452,7 +595,7 @@ int computeVec4dot(thread const int4& a, thread const int4& b)
return (((a.x * b.x) + (a.y * b.y)) + (a.z * b.z)) + (a.w * b.w);
}
kernel void main0(device T1* uOutput [[buffer(0)]], const device T0* uInput0 [[buffer(1)]], const device T0* uInput1 [[buffer(2)]], constant constBuffer& uConstant [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInput0 [[buffer(1)]], const device T0* uInput1 [[buffer(2)]], constant constBuffer& uConstant [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
int3 posTmp = int3(gl_GlobalInvocationID);
if (posTmp.x < uConstant.size.w)
@ -515,7 +658,7 @@ public:
@"T1" : T1,
@"CUSTOM" : CUSTOM,
};
pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryBroadcast, "main0", compileOptions);
pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryBroadcast, "loop_binary", compileOptions);
mtbn->runtime()->insertPipeline(keys, pipeline);
}
if (nil == pipeline) {
@ -577,9 +720,7 @@ public:
if (nullptr == loop || loop->commands() == nullptr) {
return nullptr;
}
if (nullptr != loop->initCommand()) {
return nullptr;
}
// Make Tensor Stack
if (1 == loop->commands()->size()) {
auto cmd = loop->commands()->GetAs<RegionCommand>(0);
@ -587,10 +728,10 @@ public:
if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) {
return new MetalGather(loop, bn, inputs, outputs);
}
if (OpType_MatMul == subop->type() && loop->parallel()) {
if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) {
return new MetalBatchMatMul(loop, bn);
}
if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber()) {
if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber() && nullptr == loop->initCommand()) {
std::vector<MNN::Tensor*> tensors(loop->tensorNumber());
_setTensorStack(tensors, inputs, outputs, loop);
auto srcTensor = tensors[cmd->indexes()->data()[1]];

View File

@ -466,7 +466,7 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp
}
bool OpenCLBackend::onSelectDynamicAllocator(int index, int maxIndex) {
if (mUseRecordQueue && false == mDevideOpRecord){
if (mUseRecordQueue && false == mDeviceOpRecord){
return false;
}
if (maxIndex > 2) {
@ -525,7 +525,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
return NULL;
}
if (iter == creators->end()) {
mDevideOpRecord = true;
mDeviceOpRecord = true;
#ifdef OPENCL_FALLBACK_LOG
if (nullptr != op->name()) {
MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mMemType, op->name()->c_str());
@ -565,7 +565,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
}
if (!valid) {
mDevideOpRecord = true;
mDeviceOpRecord = true;
#ifdef OPENCL_FALLBACK_LOG
for (auto t : inputs) {
auto tensorShape = OpenCL::tensorShapeFormat(t);
@ -589,7 +589,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
auto exe = iter->second->onCreate(inputs, outputs, op, this);
if (NULL == exe) {
mDevideOpRecord = true;
mDeviceOpRecord = true;
#ifdef OPENCL_FALLBACK_LOG
if (nullptr != op->name()) {
MNN_PRINT("The Creator Don't support type %s, memObject:%d, %s\n", MNN::EnumNameOpType(op->type()), mMemType, op->name()->c_str());
@ -1232,7 +1232,7 @@ int OpenCLBackend::fpBytes() {
void OpenCLBackend::clearRecord() const{
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && mDevideOpRecord){
if(mUseRecordQueue && mDeviceOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
std::vector<cl_array_arg_qcom> update_kernel_args;
std::vector<cl_workgroup_qcom> update_global_size;
@ -1263,7 +1263,7 @@ void OpenCLBackend::clearRecord() const{
void OpenCLBackend::enqeueRecord() const{
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && !mDevideOpRecord){
if(mUseRecordQueue && !mDeviceOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
std::vector<cl_array_arg_qcom> update_kernel_args;
std::vector<cl_workgroup_qcom> update_global_size;
@ -1290,7 +1290,7 @@ void OpenCLBackend::enqeueRecord() const{
void OpenCLBackend::releaseRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
if(mUseRecordQueue && !mDevideOpRecord){
if(mUseRecordQueue && !mDeviceOpRecord){
for(int i = 0; i < mRecordings.size(); ++i){
cl_int res = clReleaseRecordingQCOM(mRecordings[i].record);
MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM");
@ -1309,7 +1309,7 @@ void OpenCLBackend::startRecord(cl_recording_qcom &recording){
MNN_PRINT("start startRecord !\n");
#endif
cl_int res = CL_SUCCESS;
if(mDevideOpRecord){
if(mDeviceOpRecord){
if(recording != NULL){
clReleaseRecordingQCOM(recording);
}
@ -1330,7 +1330,7 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
#ifdef LOG_VERBOSE
MNN_PRINT("start endRecord !\n");
#endif
if(mDevideOpRecord){
if(mDeviceOpRecord){
cl_int res = CL_SUCCESS;
res = clEndRecordingQCOM(recording);
MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
@ -1349,7 +1349,7 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){
}
void OpenCLBackend::addRecord(cl_recording_qcom &record, std::vector<RecordUpdateInfo *>updateInfo){
if(mDevideOpRecord){
if(mDeviceOpRecord){
RecordInfo info;
info.record = record;
for(int i = 0; i < updateInfo.size(); ++i) {
@ -1369,7 +1369,7 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
MNN_PRINT("start record2dKernel !\n");
#endif
cl_int res = CL_SUCCESS;
if(!mDevideOpRecord){
if(!mDeviceOpRecord){
RecordInfo info;
int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums;
if(updateInfo != nullptr){
@ -1439,7 +1439,7 @@ void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, c
for (size_t i = 0; i < 3; ++i) {
internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i]));
}
if(!mDevideOpRecord){
if(!mDeviceOpRecord){
RecordInfo info;
int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums;
if(updateInfo != nullptr){
@ -1547,12 +1547,12 @@ void OpenCLBackend::setGpuMode(const int cl_mode_num) {
mUseRecordQueue = ((cl_mode_num & MNN_GPU_RECORD_OP) || (cl_mode_num & MNN_GPU_RECORD_BATCH)) && mOpenCLRuntime->isSupportRecordQueue() && (mUseRecordableQueueSize > 0);
isSet = (cl_mode_num & MNN_GPU_RECORD_OP);
if(isSet) {
mDevideOpRecord = true;
mDeviceOpRecord = true;
totalSet++;
}
isSet = (cl_mode_num & MNN_GPU_RECORD_BATCH);
if(isSet) {
mDevideOpRecord = false;
mDeviceOpRecord = false;
totalSet++;
}
if(totalSet > 1) {

View File

@ -137,7 +137,7 @@ public:
return mUseRecordQueue;
}
bool isDevideOpRecord(){
return mDevideOpRecord;
return mDeviceOpRecord;
}
CLTuneLevel getCLTuneLevel() {
return mTuneLevel;
@ -185,7 +185,8 @@ private:
bool mIsCreateError{false};
mutable std::vector<RecordInfo> mRecordings;
bool mUseRecordQueue = false;
bool mDevideOpRecord = false;
bool mDeviceOpRecord = false;
friend class setRecordClose;
uint32_t mRecordNums = 0;
uint32_t mUseRecordableQueueSize;
private:
@ -200,6 +201,25 @@ private:
};
// RAII guard that temporarily turns off the backend's OpenCL record queue.
// If recording was enabled at construction time it is disabled, and the
// destructor restores it; otherwise the guard is a no-op.
class setRecordClose{
public:
    explicit setRecordClose(OpenCLBackend *bn) : backend(bn) {
        // Only remember to restore when we actually flipped the flag,
        // so nested/idle guards leave the backend untouched.
        if (backend->mUseRecordQueue) {
            backend->mUseRecordQueue = false;
            needRecover = true;
        }
    }
    ~setRecordClose() {
        if (needRecover) {
            backend->mUseRecordQueue = true;
        }
    }
private:
    bool needRecover = false;   // true iff the ctor disabled recording
    OpenCLBackend* backend;     // non-owning; must outlive this guard
};
template <class T>
class OpenCLCreatorRegister {
public:

View File

@ -633,20 +633,21 @@ bool localWSTune(const std::map<std::string, std::vector<std::pair<std::vector<u
return true;
}
bool getPreParamInfo(const std::string preParamName, uint32_t *preParamData, OpenCLRuntime *runtime){
auto& preParamInfo = runtime->preParamsMap();
if (preParamInfo.find(preParamName) != preParamInfo.end()) {
*preParamData = preParamInfo[preParamName];
bool getTunedInfo(const std::string kernelName, const std::vector<uint32_t> &gws, std::pair<std::vector<uint32_t>, uint32_t> &tuneInfo, OpenCLRuntime *runtime){
auto& tunedLws = runtime->tunedLwsMap();
auto& tuneLws = runtime->getTuneLwsMap();
std::pair<std::string, std::vector<uint32_t>> info = std::make_pair(kernelName, gws);
if (tunedLws.find(info) != tunedLws.end()) {
tuneInfo = tunedLws[info];
return true;
}
return false;
return localWSTune(tuneLws, gws, kernelName, tuneInfo);
}
void setPreParamInfo(const std::string preParamName, uint32_t preParamData, OpenCLRuntime *runtime){
auto& preParamInfo = runtime->preParamsMap();
if (preParamInfo.find(preParamName) == preParamInfo.end()) {
preParamInfo.insert(std::make_pair(preParamName, preParamData));
}
void setTunedInfo(const std::string kernelName, const std::vector<uint32_t> &gws, std::pair<std::vector<uint32_t>, uint32_t> &tuneInfo, OpenCLRuntime *runtime){
auto& tunedLws = runtime->tunedLwsMap();
std::pair<std::string, std::vector<uint32_t>> info = std::make_pair(kernelName, gws);
tunedLws.insert(std::make_pair(info, std::make_pair(tuneInfo.first, tuneInfo.second)));
}
} // namespace OpenCL

View File

@ -132,9 +132,9 @@ uint32_t get2DUseLocalMemTime(const std::vector<uint32_t> &gws, const std::vecto
std::pair<std::vector<uint32_t>, uint32_t> localWS2DDefault(const std::vector<uint32_t> &gws, const uint32_t maxWorkGroupSize,
OpenCLRuntime *runtime, const std::string &kernelName, const std::shared_ptr<KernelWrap> &mKernel, int tuneLevel);
bool getPreParamInfo(const std::string preParamName, uint32_t *preParamData, OpenCLRuntime *runtime);
bool getTunedInfo(const std::string kernelName, const std::vector<uint32_t> &gws, std::pair<std::vector<uint32_t>, uint32_t> &tuneInfo, OpenCLRuntime *runtime);
void setPreParamInfo(const std::string preParamName, uint32_t preParamData, OpenCLRuntime *runtime);
void setTunedInfo(const std::string kernelName, const std::vector<uint32_t> &gws, std::pair<std::vector<uint32_t>, uint32_t> &tuneInfo, OpenCLRuntime *runtime);
void copyBufferToImage(OpenCLRuntime *runtime, const cl::Buffer &buffer, const cl::Image &image, int w, int h, int precision);

View File

@ -348,10 +348,6 @@ unsigned int OpenCLRuntime::getQueueNum() {
return mQueueCount;
}
std::map<std::string, uint32_t>& OpenCLRuntime::preParamsMap(){
return mPreParams;
}
std::map<std::vector<uint32_t>, std::vector<uint32_t>>& OpenCLRuntime::tunedGemmParamsMap() {
return mTunedGemmParams;
}
@ -804,14 +800,6 @@ std::pair<const void*, size_t> OpenCLRuntime::makeCache(void* tuneInfo) {
backend->gemm.emplace_back(std::move(tuning));
}
// Get All PreParam cache
for(auto& iter : mPreParams){
std::unique_ptr<PreParamInfoT> info(new PreParamInfoT);
info->preParamName = iter.first;
info->preParamData = iter.second;
backend->preParam.emplace_back(std::move(info));
}
cache->backends.emplace_back(std::move(backend));
flatbuffers::FlatBufferBuilder builder;
@ -921,23 +909,18 @@ bool OpenCLRuntime::setCache(std::pair<const void*, size_t> cache) {
}
}
}
//Load PreParam Info
if(nullptr != backendinfo->preParam()){
auto preParamInfo = backendinfo->preParam();
for(int i = 0; i < preParamInfo->size(); ++i){
auto info = preParamInfo->GetAs<PreParamInfo>(i);
if (nullptr == info->preParamName()) {
MNN_ERROR("Error preParam info\n");
return false;
}
mPreParams.insert(std::make_pair(info->preParamName()->str(), info->preParamData()));
}
}
}
return true;
}
// Blocks until `event` completes, then returns its GPU execution time in
// microseconds, derived from the OpenCL profiling counters
// (CL_PROFILING_COMMAND_START/END are reported in nanoseconds).
unsigned int OpenCLRuntime::getEventTime(cl::Event& event){
    const cl_int waitRes = event.wait();
    MNN_CHECK_CL_SUCCESS(waitRes, "clEvent");
    const auto beginNanos = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
    const auto endNanos   = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
    return (unsigned int)((endNanos - beginNanos) / 1000.0);
}
void OpenCLRuntime::printEventTime(){
#ifdef ENABLE_OPENCL_TIME_PROFILER
if(mEvents.empty()){

View File

@ -128,6 +128,7 @@ public:
void pushEvent(std::pair<std::string, cl::Event> data) {
return mEvents.push_back(data);
}
unsigned int getEventTime(cl::Event& event);
void printEventTime();
void clearEvent(){
mKernelTime = 0;
@ -143,8 +144,6 @@ public:
unsigned int mKernelTime = 0;
std::map<std::string, uint32_t>& preParamsMap();
std::map<std::vector<uint32_t>, std::vector<uint32_t>>& tunedGemmParamsMap();
std::map<std::pair<std::string, std::vector<uint32_t>>, std::pair<std::vector<uint32_t>, uint32_t>>& tunedLwsMap();
@ -232,7 +231,6 @@ private:
double mStartNanos;
double mStopNanos;
std::map<std::string, uint32_t> mPreParams;
std::map<std::vector<uint32_t>, std::vector<uint32_t>> mTunedGemmParams;
std::map<std::pair<std::string, std::vector<uint32_t>>, std::pair<std::vector<uint32_t>, uint32_t>> mTunedLws;
std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::pair<std::vector<uint32_t>, uint32_t>>>> mTuneLws;

View File

@ -17,7 +17,7 @@ KVCacheCLManager::KVCacheCLManager(Backend *backend, bool kv_cahce) : mKVCache(k
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
}
void KVCacheCLManager::allocKVCache(const KVMeta* meta, bool isDecodeResize) {
void KVCacheCLManager::allocKVCache(const KVMeta* meta) {
if (!mKVCache) {
return;
}
@ -25,10 +25,10 @@ void KVCacheCLManager::allocKVCache(const KVMeta* meta, bool isDecodeResize) {
if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){
mByte = 2;
}
reallocKVCache(meta, isDecodeResize);
reallocKVCache(meta, false);
}
bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isDecodeResize) {
bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isExecute) {
if (!mKVCache) {
return false;
}
@ -91,14 +91,14 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isDecodeResize) {
mPastKey.reset(newKey);
mPastValue.reset(newValue);
// resize phase doesn't update mPastLength; the execute phase will update it
if(false == isDecodeResize){
if(isExecute){
mPastLength = start;
}
}
// Remove
// resize phase doesn't remove the kv-cache; the execute phase will do it
if(false == isDecodeResize){
if(isExecute){
if (0 == meta->n_reserve) {
mPastLength = start;
return true;
@ -161,7 +161,7 @@ void AttentionBufExecution::handleKVCache(const std::vector<Tensor *> &inputs, c
int mask_seqlen = mask_shape[2];
int maskKvlen = mask_shape[3];
mKVCacheCLManager->setArgs(numHead, kvNumHead, headDim);
mKVCacheCLManager->allocKVCache(mMeta, mIsDecode);
mKVCacheCLManager->allocKVCache(mMeta);
mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4);
mDecodeTmpMaxlen = mKeyValueMaxlen;
mPastKvSeqlen = mKVCacheCLManager->pastKvLength();
@ -177,6 +177,9 @@ ErrorCode AttentionBufExecution::init() {
mRgQUpdateInfo.update_kernel_args.clear();
mRgQUpdateInfo.update_global_size.clear();
mRgQUpdateInfo.update_local_size.clear();
mRgMUpdateInfo.update_kernel_args.clear();
mRgMUpdateInfo.update_global_size.clear();
mRgMUpdateInfo.update_local_size.clear();
mRgUpdateInfo.update_kernel_args.clear();
mRgUpdateInfo.update_global_size.clear();
mRgUpdateInfo.update_local_size.clear();
@ -222,9 +225,36 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
mKVCacheCLManager->addKvLength(seqlen);
// prefill
if(mIsDecode == false){
int maskKvlen = mKvSeqlen;
int maskQlen = seqlen;
if(mHasMask) {
auto mask = inputs[3];
auto mask_shape = mask->shape();
maskQlen = mask_shape[2];
maskKvlen = mask_shape[3];
}
// key value static memory has been changed, need reset args
if(mKeyValueMaxlen != ROUND_UP(mKVCacheCLManager->maxLength(), 4)){
mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4);
}
if(false == mLongPrefill){
mGlobalWorkSizeQk0 = UP_DIV(mKvSeqlen, 4);
mQkPrefillGlobal_size[1] = ROUND_UP(mGlobalWorkSizeQk0, std::max((uint32_t)1, mLocalWorkSizeQk[1]));
mGlobalWorkSizeQk[1] = mQkPrefillGlobal_size[1];
mTempQK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
mTempSoftMax.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
if(mIsAddMask) {
mTempMask.reset(Tensor::createDevice<float>({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch}));
} else {
mTempMask.reset(Tensor::createDevice<uint32_t>({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch}));
}
mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
}
#ifndef ENABLE_OPENCL_TIME_PROFILER
if(mOpenCLBackend->isUseRecordQueue()){
if(mLongPrefill){
@ -233,10 +263,22 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
}else{
mRgQUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQ.get())();
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
mQkUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
mRgMUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempMask.get())();
mQkUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempQ.get())();
mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mKVCacheCLManager->key()))();
if(mHasMask){
mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempMask.get())();
mQkUpdateInfo.update_kernel_args[4].arg_value = &openCLDeferBuffer(mTempQK.get())();
}else{
mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempQK.get())();
}
mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQK.get())();
mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempSoftMax.get())();
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempSoftMax.get())();
mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))();
}
} else {
#endif
@ -256,40 +298,58 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
ret |= mKernel_rearrange->get().setArg(6, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k");
}
{
// matmul qk
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key());
ret |= mKernel_qk->get().setArg(10, mKvSeqlen);
ret |= mKernel_qk->get().setArg(11, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qk_decode");
}
{
// softmax
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(7, mKvSeqlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg softmax");
}
{
cl_int ret = CL_SUCCESS;
ret |= mKernel_rearrangeV->get().setArg(4, *mKVCacheCLManager->value());
ret |= mKernel_rearrangeV->get().setArg(5, mPastKvSeqlen);
ret |= mKernel_rearrangeV->get().setArg(6, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_v");
}
// qk * value
{
cl_int ret = CL_SUCCESS;
ret |= mKernel_qkv->get().setArg(4, *mKVCacheCLManager->value());
ret |= mKernel_qkv->get().setArg(7, mKvSeqlen);
ret |= mKernel_qkv->get().setArg(8, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qkv_decode");
}
if(mHasMask){
// rearrange mask
cl_int ret = CL_SUCCESS;
ret |= mKernel_rearrangeMask->get().setArg(4, openCLDeferBuffer(mTempMask.get()));
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_mask_shortprefill");
}
{
// matmul qk
mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(seqlen, 4)), static_cast<uint32_t>(UP_DIV(mKvSeqlen, 4)), static_cast<uint32_t>(numHead*batch)};
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(1, mGlobalWorkSizeQk0);
ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key());
if(mHasMask) {
ret |= mKernel_qk->get().setArg(5, openCLDeferBuffer(mTempMask.get()));
}
ret |= mKernel_qk->get().setArg(6, openCLDeferBuffer(mTempQK.get()));
ret |= mKernel_qk->get().setArg(10, mKvSeqlen);
ret |= mKernel_qk->get().setArg(11, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qk_decode");
mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2]));
}
{
// softmax
cl_int ret = CL_SUCCESS;
ret |= mKernel_softmax->get().setArg(3, openCLDeferBuffer(mTempQK.get()));
ret |= mKernel_softmax->get().setArg(4, openCLDeferBuffer(mTempSoftMax.get()));
ret |= mKernel_softmax->get().setArg(7, mKvSeqlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg softmax");
}
{
// rearrange value
cl_int ret = CL_SUCCESS;
ret |= mKernel_rearrangeV->get().setArg(4, *mKVCacheCLManager->value());
ret |= mKernel_rearrangeV->get().setArg(5, mPastKvSeqlen);
ret |= mKernel_rearrangeV->get().setArg(6, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_v");
}
// qk * value
{
cl_int ret = CL_SUCCESS;
ret |= mKernel_qkv->get().setArg(3, openCLDeferBuffer(mTempSoftMax.get()));
ret |= mKernel_qkv->get().setArg(4, *mKVCacheCLManager->value());
ret |= mKernel_qkv->get().setArg(7, mKvSeqlen);
ret |= mKernel_qkv->get().setArg(8, mKeyValueMaxlen);
MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qkv_decode");
}
#ifndef ENABLE_OPENCL_TIME_PROFILER
}
#endif
#ifndef ENABLE_OPENCL_TIME_PROFILER
}
#endif
return NO_ERROR;
}
@ -882,20 +942,30 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
float scale = 1.0 / sqrt(headDim);
int maskKvlen = mKvSeqlen;
int maskQlen = seqlen;
if(mHasMask) {
auto mask = inputs[3];
auto mask_shape = mask->shape();
maskQlen = mask_shape[2];
maskKvlen = mask_shape[3];
if(mIsAddMask) {
mTempMask.reset(Tensor::createDevice<float>({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch}));
} else {
mTempMask.reset(Tensor::createDevice<uint32_t>({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch}));
}
}
mTempQ.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch}));
mTempQK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
mTempSoftMax.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC);
mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION);
if(mHasMask){
mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
}
cl::Buffer keyBuffer, valueBuffer;
if(mNeedKvCache) {
@ -911,9 +981,12 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
keyBuffer = openCLBuffer(mTempK.get());
valueBuffer = openCLBuffer(mTempV.get());
}
mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
if(mHasMask){
mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
}
{
// rearrange query
@ -932,7 +1005,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
ret |= mKernel_rearrangeQ->get().setArg(index++, mGlobalWorkSizeRearrgQ[1]);
ret |= mKernel_rearrangeQ->get().setArg(index++, mGlobalWorkSizeRearrgQ[2]);
ret |= mKernel_rearrangeQ->get().setArg(index++, openCLBuffer(query));
ret |= mKernel_rearrangeQ->get().setArg(index++, openCLBuffer(mTempQ.get()));
ret |= mKernel_rearrangeQ->get().setArg(index++, openCLDeferBuffer(mTempQ.get()));
ret |= mKernel_rearrangeQ->get().setArg(index++, seqlen);
ret |= mKernel_rearrangeQ->get().setArg(index++, headDim);
ret |= mKernel_rearrangeQ->get().setArg(index++, numHead);
@ -942,8 +1015,9 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
mGlobalWorkSizeRearrgQ[0] = ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0]));
mGlobalWorkSizeRearrgQ[1] = ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1]));
mGlobalWorkSizeRearrgQ[2] = ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2]));
mRgQUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()});
mOpRecordUpdateInfo.emplace_back(&mRgQUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ);
mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, &mRgQUpdateInfo);
}
{
// rearrange key
@ -984,6 +1058,34 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo);
}
if (mHasMask){
std::set<std::string> buildOption;
if(mIsAddMask){
buildOption.emplace("-DADD_MASK");
} else if(mHasMask) {
buildOption.emplace("-DSET_MASK");
}
mKernel_rearrangeMask = runtime->buildKernel("attention_buf", "rearrange_mask_shortprefill", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
mGlobalWorkSizeRearrgM = {static_cast<uint32_t>(UP_DIV(maskQlen, 4)), static_cast<uint32_t>(UP_DIV(maskKvlen, 4)), static_cast<uint32_t>(batch)};
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrangeMask));
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[0]);
ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[1]);
ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[2]);
ret |= mKernel_rearrangeMask->get().setArg(index++, openCLBuffer(inputs[3]));
ret |= mKernel_rearrangeMask->get().setArg(index++, openCLDeferBuffer(mTempMask.get()));
ret |= mKernel_rearrangeMask->get().setArg(index++, maskQlen);
ret |= mKernel_rearrangeMask->get().setArg(index++, maskKvlen);
MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask_shortprefill");
mLocalWorkSizeRearrgM = localWS3DDefault(mGlobalWorkSizeRearrgM, maxWorkGroupSize, runtime, "rearrange_mask_shortprefill", mKernel_rearrangeMask, mOpenCLBackend->getCLTuneLevel()).first;
mGlobalWorkSizeRearrgM[0] = ROUND_UP(mGlobalWorkSizeRearrgM[0], std::max((uint32_t)1, mLocalWorkSizeRearrgM[0]));
mGlobalWorkSizeRearrgM[1] = ROUND_UP(mGlobalWorkSizeRearrgM[1], std::max((uint32_t)1, mLocalWorkSizeRearrgM[1]));
mGlobalWorkSizeRearrgM[2] = ROUND_UP(mGlobalWorkSizeRearrgM[2], std::max((uint32_t)1, mLocalWorkSizeRearrgM[2]));
mRgMUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()});
mOpRecordUpdateInfo.emplace_back(&mRgMUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, &mRgMUpdateInfo);
}
{
// matmul qk
std::set<std::string> buildOption;
@ -1002,12 +1104,12 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[2]);
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQ.get()));
ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQ.get()));
ret |= mKernel_qk->get().setArg(index++, keyBuffer);
if(mHasMask) {
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(inputs[3]));
ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempMask.get()));
}
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQK.get()));
ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
ret |= mKernel_qk->get().setArg(index++, scale);
ret |= mKernel_qk->get().setArg(index++, seqlen);
ret |= mKernel_qk->get().setArg(index++, maskKvlen);
@ -1021,16 +1123,25 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2]));
mQkUpdateInfo.update_kernel_args.push_back({0, 1, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0});
mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()});
if(mNeedKvCache) {
mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
}
if(mHasMask){
mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()});
mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKvSeqlen), &mKvSeqlen});
mQkUpdateInfo.update_kernel_args.push_back({0, 11, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
}else{
mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
mQkUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(mKvSeqlen), &mKvSeqlen});
mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
}
mQkPrefillGlobal_size[0] = mGlobalWorkSizeQk[0];
mQkPrefillGlobal_size[1] = mGlobalWorkSizeQk[1];
mQkPrefillGlobal_size[2] = mGlobalWorkSizeQk[2];
mQkUpdateInfo.update_global_size.push_back({0, mQkPrefillGlobal_size});
mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
}
@ -1051,8 +1162,8 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempQK.get()));
ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
ret |= mKernel_softmax->get().setArg(index++, inside);
ret |= mKernel_softmax->get().setArg(index++, outside);
ret |= mKernel_softmax->get().setArg(index++, mKvSeqlen);
@ -1065,9 +1176,11 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen});
mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
}
{
// rearrange value
@ -1120,7 +1233,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]);
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
ret |= mKernel_qkv->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
ret |= mKernel_qkv->get().setArg(index++, valueBuffer);
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
ret |= mKernel_qkv->get().setArg(index++, seqlen);
@ -1135,6 +1248,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector<Tensor *> &inpu
mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2]));
mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
if(mNeedKvCache) {
mQkvUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
}
@ -1164,15 +1278,7 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector<Tensor *> &input
int group_size = numHead / kvNumHead;
float scale = 1.0 / sqrt(headDim);
int mask_seqlen = seqlen;
int mask_kvlen = seqlen;
if(mHasMask) {
auto mask = inputs[3];
auto mask_shape = mask->shape();
mask_seqlen = mask_shape[2];
mask_kvlen = mask_shape[3];
}
cl::Buffer keyBuffer, valueBuffer;
if(mNeedKvCache) {
keyBuffer = *mKVCacheCLManager->key();
@ -1440,13 +1546,17 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector<Tensor *> &input
ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
mOpenCLBackend->startRecord(mRecording);
auto shape = inputs[0]->shape();
int batch = shape[0];
int seqlen = shape[1];
int numHead = shape[2];
int headDim = shape[3];
int kvNumHead = inputs[1]->shape()[2];
if(mNeedKvCache) {
// if has kv_cache, default has mask
MNN_ASSERT(inputs.size() > 3);
}
mHasMask = inputs.size() > 3;
mIsDecode = seqlen == 1;
mIsDecode = seqlen == 1 && mMeta->add == 1;
// reset updateArgs variable and kernel vector
init();
@ -1457,22 +1567,99 @@ ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, c
if(mIsDecode) {
return decodeResize(inputs, outputs);
} else {
if(seqlen > 512 && mPastKvSeqlen == 0){
mLongPrefill = true;
return longPrefillResize(inputs, outputs);
if(mPastKvSeqlen == 0){
std::pair<std::vector<uint32_t>, uint32_t> tuneInfo;
std::string info = "attention_" + std::to_string(batch) + "_" + std::to_string(numHead) + "_" + std::to_string(headDim) + "_" + std::to_string(kvNumHead);
if(seqlen > 16){
if(getTunedInfo(info, {static_cast<unsigned int>(seqlen)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime())){
mLongPrefill = tuneInfo.first[0];
} else{
if (mOpenCLBackend->getCLTuneLevel() == Heavy || mOpenCLBackend->getCLTuneLevel() == Wide){
setRecordClose closeRecord(mOpenCLBackend);
// tunning choose use witch preill
prefillResize(inputs, outputs);
auto shortPrefillTime = getExecuteTime();
init();
mLongPrefill = true;
longPrefillResize(inputs, outputs);
auto longPrefillTime = getExecuteTime();
mLongPrefill = false;
if(longPrefillTime < shortPrefillTime){
mLongPrefill = true;
}
std::pair<std::vector<uint32_t>, uint32_t> tuneInfoTmp = std::make_pair<std::vector<uint32_t>, uint32_t>({mLongPrefill}, 0);
setTunedInfo(info, {static_cast<unsigned int>(seqlen)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime());
init();
}else{
if(seqlen > 512){
mLongPrefill = true;
}
}
}
}
}
if(mLongPrefill){
longPrefillResize(inputs, outputs);
}else{
return prefillResize(inputs, outputs);
prefillResize(inputs, outputs);
}
}
return NO_ERROR;
}
int AttentionBufExecution::getExecuteTime(){
int executeTime = 0;
auto runtime = mOpenCLBackend->getOpenCLRuntime();
if(mLongPrefill) {
int seq_idx = 0;
cl::Event event0, event1, event2, event3, event4, event5, event6;
run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event0);
executeTime += runtime->getEventTime(event0);
if(mHasMask) {
run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event1);
executeTime += runtime->getEventTime(event1);
}
for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) {
run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event2);
executeTime += runtime->getEventTime(event2);
run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event3);
executeTime += runtime->getEventTime(event3);
run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event4);
executeTime += runtime->getEventTime(event4);
run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event5);
executeTime += runtime->getEventTime(event5);
}
seq_idx = 0;
run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event6);
executeTime += runtime->getEventTime(event6);
} else{
cl::Event event0, event1, event2, event3, event4, event5, event6;
run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0);
executeTime += runtime->getEventTime(event0);
run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1);
executeTime += runtime->getEventTime(event1);
if(mHasMask) {
run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2);
executeTime += runtime->getEventTime(event2);
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3);
executeTime += runtime->getEventTime(event3);
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4);
executeTime += runtime->getEventTime(event4);
run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5);
executeTime += runtime->getEventTime(event5);
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6);
executeTime += runtime->getEventTime(event6);
}
return executeTime;
}
ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("start AttentionBufExecution onExecute !\n");
#endif
if(mNeedKvCache && mIsDecode){
if(mNeedKvCache){
mKVCacheCLManager->reallocKVCache(mMeta);
}
UpdateArgs(inputs, outputs);
@ -1513,19 +1700,23 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor *> &inputs,
runKernel2D(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event4);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event4});
}else{
cl::Event event0, event1, event2, event3, event4, event5;
cl::Event event0, event1, event2, event3, event4, event5, event6;
run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_q", event0});
run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_k", event1});
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event2);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event2});
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event3);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event3});
run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event4);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event4});
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event5);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event5});
if(mHasMask) {
run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_mask_shortprefill", event2});
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event3});
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event4});
run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event5});
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event6});
}
}
#else
@ -1562,6 +1753,9 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor *> &inputs,
}else{
run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime());
if(mHasMask) {
run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime());
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime());

View File

@ -22,8 +22,8 @@ public:
KVCacheCLManager(Backend *backend, bool kv_cache);
~KVCacheCLManager() = default;
void allocKVCache(const KVMeta* meta, bool isDecodeResize = false);
bool reallocKVCache(const KVMeta* meta, bool isDecodeResize = false);
void allocKVCache(const KVMeta* meta);
bool reallocKVCache(const KVMeta* meta, bool isExecute = true);
void setArgs(int numHead, int kvNumHead, int headDim){
mNumHead = numHead;
mKvNumHead = kvNumHead;
@ -67,6 +67,7 @@ public:
ErrorCode UpdateArgs(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode init();
int getExecuteTime();
virtual ~AttentionBufExecution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -88,12 +89,14 @@ private:
OpenCLBackend *mOpenCLBackend;
RecordUpdateInfo mRgUpdateInfo;
RecordUpdateInfo mRgQUpdateInfo;
RecordUpdateInfo mRgMUpdateInfo;
RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo;
RecordUpdateInfo mRgVUpdateInfo;
RecordUpdateInfo mQkvUpdateInfo;
int mGlobalWorkSizeQk0 = 0;
size_t mQkGlobal_size[2];
size_t mQkPrefillGlobal_size[3];
std::vector<RecordUpdateInfo*> mOpRecordUpdateInfo;
std::shared_ptr<KVCacheCLManager> mKVCacheCLManager;
std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
@ -131,6 +134,7 @@ private:
private:
std::shared_ptr<KernelWrap> mKernel_rearrangeQ;
std::shared_ptr<KernelWrap> mKernel_rearrangeV;
std::shared_ptr<KernelWrap> mKernel_rearrangeMask;
std::shared_ptr<KernelWrap> mKernel_rearrange;
std::shared_ptr<KernelWrap> mKernel_qk;
std::shared_ptr<KernelWrap> mKernel_softmax;
@ -148,6 +152,8 @@ private:
std::vector<uint32_t> mLocalWorkSizeRearrgV;
std::vector<uint32_t> mGlobalWorkSizeRearrg;
std::vector<uint32_t> mLocalWorkSizeRearrg;
std::vector<uint32_t> mGlobalWorkSizeRearrgM;
std::vector<uint32_t> mLocalWorkSizeRearrgM;
};
} // namespace OpenCL

View File

@ -614,15 +614,12 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
const int blockDim = mResource->mInputChannel / mResource->mBlockSize;
bool useLocalMem = inputChannels >= 32;
std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel);
std::string kernelName = "gemv_conv_c8";
std::set<std::string> buildOption = mResource->mBuildOptions;
int inputChannelLeaves = 0;
if(mResource->mNumQuantBit == 4){
inputChannelLeaves = useLocalMem ? (inputChannels % 4) : (blockDim % 4);
kernelName += "_int4_buf";
} else {
inputChannelLeaves = useLocalMem ? (inputChannels % 2) : (blockDim % 2);
kernelName += "_int8_buf";
}
if(outChannel % 8 != 0){
buildOption.emplace("-DOUTPUT_CHANNEL_LEAVES");
@ -638,7 +635,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
for (int ksize = 8; ksize <= 256; ksize*=2) {
auto option = buildOption;
option.emplace("-DWGS=" + std::to_string(ksize));
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName, option, mOpenCLBackend->getPrecision());
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel));
std::vector<uint32_t> gws = {static_cast<uint32_t>(ksize), static_cast<uint32_t>(UP_DIV(outChannel, 8))};
std::vector<uint32_t> lws = {static_cast<uint32_t>(ksize), 1};
@ -646,6 +643,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
cl_int ret = CL_SUCCESS;
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[0]));
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[1]));
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[1]));
ret |= kernel->get().setArg(idx++, openCLBuffer(input));
if(mResource->mUseImage){
ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get());
@ -657,13 +655,15 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
ret |= kernel->get().setArg(idx++, openCLBuffer(output));
ret |= kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= kernel->get().setArg(idx++, inputChannels);
ret |= kernel->get().setArg(idx++, static_cast<int>(blockNum));
ret |= kernel->get().setArg(idx++, static_cast<int>(blockDim));
ret |= kernel->get().setArg(idx++, static_cast<float>(mResource->mCoef));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select");
std::pair<std::vector<uint32_t>, int> retTune;
int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, kernel);
int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, kernel);
if(min_time > cost_time) {
local_size = ksize;
min_time = cost_time;
@ -673,12 +673,13 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
buildOption.emplace("-DWGS=" + std::to_string(local_size));
mGlobalWorkSize = {static_cast<uint32_t>(local_size), static_cast<uint32_t>(UP_DIV(outChannel, 8))};
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName, buildOption, mOpenCLBackend->getPrecision());
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[0]));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[1]));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[1]));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
if(mResource->mUseImage){
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get());
@ -690,15 +691,17 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannels));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(blockNum));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(blockDim));
ret |= unit.kernel->get().setArg(idx++, static_cast<float>(mResource->mCoef));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c4_0_buf");
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf");
if(useLocalMem){
mLocalWorkSize = {static_cast<uint32_t>(local_size), 1};
}else{
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
}
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
@ -747,8 +750,142 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, option, mOpenCLBackend->getPrecision());
}
std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel);
if(global_y <= 16) {
mUnits.resize(3);
int outputChannelAlign8 = ROUND_UP(outChannel, 8);
mConvGemmInpTensor.reset(Tensor::createDevice<float>({inputChannelAlign * ROUND_UP(global_y, 4)}));
mConvGemmOutTensor.reset(Tensor::createDevice<float>({outputChannelAlign8 * ROUND_UP(global_y, 4)}));
mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC);
mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC);
mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC);
{
//c4nhw4 -> nhwc
auto &unit = mUnits[0];
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_c4nhw4_to_nhwc", buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
mGlobalWorkSize = {static_cast<uint32_t>(UP_DIV(global_y, 4)), static_cast<uint32_t>(UP_DIV(inputChannels, 4))};
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get()));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(global_y));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannels));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannelAlign));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_c4nhw4_to_nhwc");
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_c4nhw4_to_nhwc", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]};
}
{
const int inputChannelBlocks = UP_DIV(inputChannels, 4);
const int outputChannelBlocks = UP_DIV(outChannel, 4);
auto &unit = mUnits[1];
std::set<std::string> buildOption = mResource->mBuildOptions;
if(mResource->mUseImage){
buildOption.emplace("-DUSE_IMAGE");
}
buildOption.emplace("-DCOMPUTE_BATCH");
int local_size = 64;
if(mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast){
int min_time = INT_MAX;
for (int ksize = 16; ksize <= 256; ksize*=2) {
auto option = buildOption;
option.emplace("-DWGS=" + std::to_string(ksize));
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel));
std::vector<uint32_t> gws = {static_cast<uint32_t>(ksize), static_cast<uint32_t>(UP_DIV(outChannel, 8)), static_cast<uint32_t>(UP_DIV(global_y, 4))};
std::vector<uint32_t> lws = {static_cast<uint32_t>(ksize), 1, 1};
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[0]));
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[1]));
ret |= kernel->get().setArg(idx++, static_cast<int>(gws[2]));
ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get()));
if(mResource->mUseImage){
ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get()));
ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get()));
ret |= kernel->get().setArg(idx++, static_cast<int>(outputChannelAlign8));
ret |= kernel->get().setArg(idx++, static_cast<int>(inputChannelAlign));
ret |= kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= kernel->get().setArg(idx++, inputChannels);
ret |= kernel->get().setArg(idx++, static_cast<int>(blockNum));
ret |= kernel->get().setArg(idx++, static_cast<int>(blockDim));
ret |= kernel->get().setArg(idx++, static_cast<float>(mResource->mCoef));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select");
std::pair<std::vector<uint32_t>, int> retTune;
int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info + "_batch", kernel);
if(min_time > cost_time) {
local_size = ksize;
min_time = cost_time;
}
}
}
buildOption.emplace("-DWGS=" + std::to_string(local_size));
mGlobalWorkSize = {static_cast<uint32_t>(local_size), static_cast<uint32_t>(UP_DIV(outChannel, 8)), static_cast<uint32_t>(UP_DIV(global_y, 4))};
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[0]));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[1]));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(mGlobalWorkSize[2]));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get()));
if(mResource->mUseImage){
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get());
}else{
ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get());
}
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get()));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(outputChannelAlign8));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannelAlign));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(outputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannelBlocks));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(inputChannels));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(blockNum));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(blockDim));
ret |= unit.kernel->get().setArg(idx++, static_cast<float>(mResource->mCoef));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf");
mLocalWorkSize = {static_cast<uint32_t>(local_size), 1, 1};
mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]};
}
{
auto &unit = mUnits[2];
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_nhwc_to_c4nhw4", buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
mGlobalWorkSize = {static_cast<uint32_t>(UP_DIV(global_y, 4)), static_cast<uint32_t>(UP_DIV(outChannel, 4))};
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]);
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get()));
ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(global_y));
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(outputChannelAlign8));
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_nhwc_to_c4nhw4");
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_nhwc_to_c4nhw4", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]};
}
return;
}
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
@ -773,7 +910,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu
ret |= unit.kernel->get().setArg(idx++, static_cast<int>(blockDim));
ret |= unit.kernel->get().setArg(idx++, mResource->mCoef);
MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv1x1_buf");
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first;
mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize);
unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]};
@ -814,13 +951,7 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector<Tensor *>
} else if (conv2dCommonParams->relu6()) {
mResource->mBuildOptions.emplace("-DRELU6");
}
if (mResource->mNumQuantBit == 8) {
// int8 case
mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8");
} else if (mResource->mNumQuantBit == 4){
// int4 case
mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4");
} else {/* More types to be supported. */}
mResource->mBuildOptions.emplace("-DQUANT_BIT=" + std::to_string(mResource->mNumQuantBit));
#ifdef LOG_VERBOSE
MNN_PRINT("end ConvBufLowMemoryExecution init !\n");
#endif
@ -870,11 +1001,35 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector<Tensor *> &input
if(batch == 1){
tuneGemvLowMemory(input, output);
} else {
// when batch is big, convert to float weight and do gemm computation in floating field
if(batch > 128){
std::pair<std::vector<uint32_t>, uint32_t> tuneInfo;
std::string info = "convBufLowMemory_" + std::to_string(mResource->mInputChannel) + "_" + std::to_string(mResource->mOutputChannel);
if(batch > 16){
if(getTunedInfo(info, {static_cast<unsigned int>(batch)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime())){
mUseFPWeight = tuneInfo.first[0];
} else{
if((mOpenCLBackend->getCLTuneLevel() == Heavy || mOpenCLBackend->getCLTuneLevel() == Wide)){
setRecordClose closeRecord(mOpenCLBackend);
tuneGemmLowMemory(input, output);
auto shortBatchTime = getExecuteTime();
mUseFPWeight = true;
useFPWeightGemmLowMemory(input, output);
auto longBatchTime = getExecuteTime();
mUseFPWeight = false;
if(longBatchTime < shortBatchTime){
mUseFPWeight = true;
}
std::pair<std::vector<uint32_t>, uint32_t> tuneInfoTmp = std::make_pair<std::vector<uint32_t>, uint32_t>({mUseFPWeight}, 0);
setTunedInfo(info, {static_cast<unsigned int>(batch)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime());
} else{
if(batch > 512){
mUseFPWeight = true;
}
}
}
}
if(mUseFPWeight){
useFPWeightGemmLowMemory(input, output);
mUseFPWeight = true;
} else {
}else{
tuneGemmLowMemory(input, output);
}
}
@ -900,6 +1055,66 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector<Tensor *> &input
return NO_ERROR;
}
int ConvBufLowMemoryExecution::getExecuteTime(){
for (auto &unit : mUnits) {
bool lws_null = true;
for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) {
unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i]));
if(unit.localWorkSize.get()[i] != 0) {
lws_null = false;
}
}
if(lws_null){
unit.localWorkSize = cl::NullRange;
}
}
int executeTime = 0;
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto res = CL_SUCCESS;
if(mUseFPWeight){
// arrange input and weight
int i = 0;
for (; i < 2; ++i){
auto unit = mUnits[i];
cl::Event event;
res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
cl::NullRange,
unit.globalWorkSize,
unit.localWorkSize,
nullptr,
&event);
executeTime += runtime->getEventTime(event);
}
// call gemm execute
executeTime += mStrassenComputor->getExecuteTime();
// rearrange output
for (; i < mUnits.size(); ++i){
auto unit = mUnits[i];
cl::Event event;
res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
cl::NullRange,
unit.globalWorkSize,
unit.localWorkSize,
nullptr,
&event);
executeTime += runtime->getEventTime(event);
}
}else{
for (auto &unit : mUnits) {
cl::Event event;
res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
cl::NullRange,
unit.globalWorkSize,
unit.localWorkSize,
nullptr,
&event);
executeTime += runtime->getEventTime(event);
}
}
return executeTime;
}
ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("Start ConvBufLowMemoryExecution onExecute !\n");

View File

@ -25,6 +25,7 @@ public:
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
int getExecuteTime();
void getInfoFromOpLowMemory(void *weight_ptr);
void set1x1WeightLowMemory();
void setGeneralWeightLowMemory();

View File

@ -283,7 +283,7 @@ LoopBinaryBufExecution::LoopBinaryBufExecution(const LoopParam *loop, const std:
: CommonExecution(bn, op) {
mLoop = loop;
mTensors.resize(mLoop->tensorNumber());
mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute);
mBuildOptions.emplace("-DOPERATOR=" + compute);
}
ErrorCode LoopBinaryBufExecution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {

View File

@ -460,6 +460,26 @@ ErrorCode StrassenMatrixComputor::onEncode(int e, int l, int h, int as, int bs,
return _generateMatMul(e, l, h, a, b, c, bias, 0, useBias);
}
// Profile the kernel sequence encoded by onEncode: enqueue every unit with a
// profiling event and accumulate the measured GPU time.
// Returns: summed per-kernel event time reported by getEventTime().
int StrassenMatrixComputor::getExecuteTime() {
    // All kernels were prepared in onEncode; here we only enqueue and time them.
    auto res = CL_SUCCESS;
    int executeTime = 0;
    auto runtime = mOpenCLBackend->getOpenCLRuntime(); // hoist loop-invariant lookup
    for (auto &unit : mUnits) {
        // A zero local size means the runtime should pick the work-group size.
        if(unit.localWorkSize[0] == 0 || unit.localWorkSize[1] == 0) {
            unit.localWorkSize = cl::NullRange;
        }
        cl::Event event;
        res |= runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(),
                                            cl::NullRange,
                                            unit.globalWorkSize,
                                            unit.localWorkSize,
                                            nullptr,
                                            &event);
        executeTime += runtime->getEventTime(event);
    }
    // `res` accumulates any enqueue error; this is a tuning-only path, so we
    // do not abort, but the status must not be silently discarded.
    (void)res;
    return executeTime;
}
void StrassenMatrixComputor::onExecute() {
// All is done in onResize, just execute it
auto res = CL_SUCCESS;

View File

@ -29,6 +29,7 @@ public:
ErrorCode onEncode(int e, int l, int h, int as, int bs, int cs, const cl::Buffer AT, const cl::Buffer BT, cl::Buffer CT, bool useBias, const cl::Buffer Bias);
int getExecuteTime();
void onExecute();
void onReset();

View File

@ -418,6 +418,74 @@ __kernel void rearrange_v(GLOBAL_SIZE_3_DIMS
#endif
}
// Transpose the attention mask from row-major [query_seq_len, mask_key_seq_len]
// into a column-major tiled layout [mask_key_seq_len4, query_seq_len4] so the
// short-prefill QK kernel can vload4 along the query dimension.
// Each work-item handles one 4x4 tile: it loads up to four rows of four mask
// values (zero-padding rows/columns that fall outside the real extents),
// transposes the tile in registers, and stores four query-contiguous vectors.
// ADD_MASK selects the float (additive) mask variant; otherwise the mask is
// int (set-mask variant).
// NOTE(review): z (batch) is read but never used in addressing — consistent
// with the documented [1 1 ...] layout; presumably batch==1. Confirm callers.
__kernel void rearrange_mask_shortprefill(GLOBAL_SIZE_3_DIMS
#ifdef ADD_MASK
__global const FLOAT* mask,
__global FLOAT* maskout,
#else
__global const int* mask, // [1 1 query_seq_len mask_key_seq_len4]
__global int* maskout, // [1 1 mask_key_seq_len4 query_seq_len4]
#endif
__private const int query_seq_len,
__private const int mask_key_seq_len){
const int x = get_global_id(0); // query_seq_len4
const int y = get_global_id(1); // mask_key_seq_len4
const int z = get_global_id(2); // batch
DEAL_NON_UNIFORM_DIM3(x, y, z);
// Top-left corner of this work-item's 4x4 tile in (query, key) coordinates.
const int x4 = x << 2;
const int y4 = y << 2;
float4 mask_tmp0, mask_tmp1, mask_tmp2, mask_tmp3;
float4 mask0, mask1, mask2, mask3;
// Source is row-major over keys: row stride == mask_key_seq_len.
int mask_offset = x4 * mask_key_seq_len + y4;
if(x4 + 3 < query_seq_len && y4 + 3 < mask_key_seq_len){
// Fast path: the whole 4x4 tile is in range — four full vector loads.
mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp1 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp2 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp3 = convert_float4(vload4(0, mask + mask_offset));
} else{
// Tail handling: query rows beyond query_seq_len become zero vectors;
// key columns beyond mask_key_seq_len are zero-padded lane by lane.
if(y4 + 3 < mask_key_seq_len){
// Only the query dimension is short — full-width loads, zero short rows.
mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset));
} else if(y4 + 1 == mask_key_seq_len){
// One valid key column in this tile.
mask_tmp0 = (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len;
mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len;
mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len;
mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0);
}else if(y4 + 2 == mask_key_seq_len){
// Two valid key columns.
mask_tmp0 = (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len;
mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len;
mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len;
mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0);
}else if(y4 + 3 == mask_key_seq_len){
// Three valid key columns.
mask_tmp0 = (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len;
mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len;
mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len;
mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0);
}
}
// In-register 4x4 transpose: maskN gathers lane N of each loaded row, i.e.
// one key column laid out across four consecutive query positions.
mask0 = (float4)(mask_tmp0.s0, mask_tmp1.s0, mask_tmp2.s0, mask_tmp3.s0);
mask1 = (float4)(mask_tmp0.s1, mask_tmp1.s1, mask_tmp2.s1, mask_tmp3.s1);
mask2 = (float4)(mask_tmp0.s2, mask_tmp1.s2, mask_tmp2.s2, mask_tmp3.s2);
mask3 = (float4)(mask_tmp0.s3, mask_tmp1.s3, mask_tmp2.s3, mask_tmp3.s3);
// Destination is row-major over queries padded to a multiple of 4:
// row stride == query_seq_len4.
int query_seq_len4 = ((query_seq_len + 3) / 4) * 4;
int output_offset = y4 * query_seq_len4 + x4;
#ifdef ADD_MASK
vstore4(CONVERT_FLOAT4(mask0), 0, maskout + output_offset);
vstore4(CONVERT_FLOAT4(mask1), 0, maskout + output_offset + query_seq_len4);
vstore4(CONVERT_FLOAT4(mask2), 0, maskout + output_offset + query_seq_len4 + query_seq_len4);
vstore4(CONVERT_FLOAT4(mask3), 0, maskout + output_offset + query_seq_len4 + query_seq_len4 + query_seq_len4);
#else
vstore4(convert_int4(mask0), 0, maskout + output_offset);
vstore4(convert_int4(mask1), 0, maskout + output_offset + query_seq_len4);
vstore4(convert_int4(mask2), 0, maskout + output_offset + query_seq_len4 + query_seq_len4);
vstore4(convert_int4(mask3), 0, maskout + output_offset + query_seq_len4 + query_seq_len4 + query_seq_len4);
#endif
}
__kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS
__global const FLOAT *query, // [batch head_num head_dim_4 query_seq_len_4]
__global const FLOAT *past_key, // [batch kv_head_num head_dim_4 kv_max_length]
@ -486,15 +554,13 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS
out3 *= (float4)scale;
{
#if defined(ADD_MASK) || defined(SET_MASK)
int mask_offset = x4 * mask_key_seq_len + y4;
float4 mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
float4 mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
float4 mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
float4 mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset));
float4 mask0 = (float4)(mask_tmp0.s0, mask_tmp1.s0, mask_tmp2.s0, mask_tmp3.s0);
float4 mask1 = (float4)(mask_tmp0.s1, mask_tmp1.s1, mask_tmp2.s1, mask_tmp3.s1);
float4 mask2 = (float4)(mask_tmp0.s2, mask_tmp1.s2, mask_tmp2.s2, mask_tmp3.s2);
float4 mask3 = (float4)(mask_tmp0.s3, mask_tmp1.s3, mask_tmp2.s3, mask_tmp3.s3);
int query_seq_len4 = ((query_seq_len + 3) / 4) * 4;
int mask_clp = y4 + mask_key_seq_len - key_seq_len;
int mask_offset = mask_clp * query_seq_len4 + x4;
float4 mask0 = mask_clp >= 0 && mask_clp < mask_key_seq_len ? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4;
float4 mask1 = mask_clp + 1 >= 0 && mask_clp + 1 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4;
float4 mask2 = mask_clp + 2 >= 0 && mask_clp + 2 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4;
float4 mask3 = mask_clp + 3 >= 0 && mask_clp + 3 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0;
#endif
#ifdef ADD_MASK

View File

@ -361,6 +361,73 @@ const char* attention_buf =
" vstore4(value_vec,0,past_value+output_offset);\n"
"#endif\n"
"}\n"
"__kernel void rearrange_mask_shortprefill(GLOBAL_SIZE_3_DIMS\n"
" #ifdef ADD_MASK\n"
" __global const FLOAT* mask,\n"
" __global FLOAT* maskout,\n"
" #else\n"
" __global const int* mask,// [1 1 query_seq_len mask_key_seq_len4]\n"
" __global int* maskout,// [1 1 mask_key_seq_len4 query_seq_len4]\n"
" #endif\n"
" __private const int query_seq_len,\n"
" __private const int mask_key_seq_len){\n"
" const int x=get_global_id(0); // query_seq_len4\n"
" const int y=get_global_id(1); // mask_key_seq_len4\n"
" const int z=get_global_id(2); // batch\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" float4 mask_tmp0,mask_tmp1,mask_tmp2,mask_tmp3;\n"
" float4 mask0,mask1,mask2,mask3;\n"
" int mask_offset=x4*mask_key_seq_len+y4;\n"
" if(x4+3<query_seq_len && y4+3<mask_key_seq_len){\n"
" mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=convert_float4(vload4(0,mask+mask_offset));\n"
" } else{\n"
" if(y4+3<mask_key_seq_len){\n"
" mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n"
" } else if(y4+1 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0);\n"
" }else if(y4+2 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0);\n"
" }else if(y4+3 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0);\n"
" }\n"
" }\n"
" mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n"
" mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n"
" mask2=(float4)(mask_tmp0.s2,mask_tmp1.s2,mask_tmp2.s2,mask_tmp3.s2);\n"
" mask3=(float4)(mask_tmp0.s3,mask_tmp1.s3,mask_tmp2.s3,mask_tmp3.s3);\n"
" \n"
" int query_seq_len4=((query_seq_len+3)/4)*4;\n"
" int output_offset=y4*query_seq_len4+x4;\n"
" #ifdef ADD_MASK\n"
" vstore4(CONVERT_FLOAT4(mask0),0,maskout+output_offset);\n"
" vstore4(CONVERT_FLOAT4(mask1),0,maskout+output_offset+query_seq_len4);\n"
" vstore4(CONVERT_FLOAT4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n"
" vstore4(CONVERT_FLOAT4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n"
" #else\n"
" vstore4(convert_int4(mask0),0,maskout+output_offset);\n"
" vstore4(convert_int4(mask1),0,maskout+output_offset+query_seq_len4);\n"
" vstore4(convert_int4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n"
" vstore4(convert_int4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n"
" #endif\n"
"}\n"
"__kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *query,// [batch head_num head_dim_4 query_seq_len_4]\n"
" __global const FLOAT *past_key,// [batch kv_head_num head_dim_4 kv_max_length]\n"
@ -428,15 +495,13 @@ const char* attention_buf =
" out3 *= (float4)scale;\n"
" {\n"
" #if defined(ADD_MASK) || defined(SET_MASK)\n"
" int mask_offset=x4*mask_key_seq_len+y4;\n"
" float4 mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" float4 mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" float4 mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" float4 mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n"
" float4 mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n"
" float4 mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n"
" float4 mask2=(float4)(mask_tmp0.s2,mask_tmp1.s2,mask_tmp2.s2,mask_tmp3.s2);\n"
" float4 mask3=(float4)(mask_tmp0.s3,mask_tmp1.s3,mask_tmp2.s3,mask_tmp3.s3);\n"
" int query_seq_len4=((query_seq_len+3)/4)*4;\n"
" int mask_clp=y4+mask_key_seq_len-key_seq_len;\n"
" int mask_offset=mask_clp*query_seq_len4+x4;\n"
" float4 mask0=mask_clp >= 0 && mask_clp<mask_key_seq_len ? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask1=mask_clp+1 >= 0 && mask_clp+1<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask2=mask_clp+2 >= 0 && mask_clp+2<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask3=mask_clp+3 >= 0 && mask_clp+3<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0;\n"
" #endif\n"
" \n"
" #ifdef ADD_MASK\n"

View File

@ -30,21 +30,22 @@ __kernel void buffer_convert_to_buffer(GLOBAL_SIZE_3_DIMS
DEAL_NON_UNIFORM_DIM3(wh, c, n);
int w = wh % shape.w;
int h = wh / shape.w;
int input_offset, output_offset;
#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW
int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC
int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
int input_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4);
input_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4);
#endif
#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW
int output_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
output_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC
int output_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
output_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
int output_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4);
output_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4);
#endif
output_ptr[output_offset] = input_ptr[input_offset];

View File

@ -24,20 +24,21 @@ const char* buffer_convert_buf =
" DEAL_NON_UNIFORM_DIM3(wh,c,n);\n"
" int w=wh % shape.w;\n"
" int h=wh/shape.w;\n"
" int input_offset,output_offset;\n"
" \n"
"#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
" input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
" input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int input_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
" input_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
"#endif\n"
"#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
" output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
" output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int output_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
" output_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n"
"#endif\n"
" output_ptr[output_offset]=input_ptr[input_offset];\n"
"}\n"

View File

@ -42,21 +42,8 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
return; \
}
#define UNIT 4
#define MOD_NUM 15
#ifdef INPUT_CHANNEL_LEAVE
#define PADZEROSVEC(k, channel, data0, data1, data2, data3) \
data0 = (k << 2) < channel ? data0 : 0; \
data1 = (k << 2) + 1 < channel ? data1 : 0; \
data2 = (k << 2) + 2 < channel ? data2 : 0; \
data3 = (k << 2) + 3 < channel ? data3 : 0;
#else
#define PADZEROSVEC(k, channel, data0, data1, data2, data3)
#endif
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __read_only image2d_t input,
@ -188,17 +175,11 @@ void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __rea
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
@ -210,10 +191,6 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
__private const int2 stride_shape,
__private const int output_width_4,
__private const int out_channel_blocks
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int inChannel
#endif
) {
const int output_channel_width_idx = get_global_id(0);
@ -223,13 +200,8 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
const int output_channel_block_idx = output_channel_width_idx / output_width_4;
const int output_width_block_idx = output_channel_width_idx % output_width_4;
#if (defined USE_LOW_BIT_WEIGHT_INT4)
int weight_ic_offset = output_channel_block_idx * 8;
int weight_oc_offset = out_channel_blocks * 8;
#else
int weight_ic_offset = output_channel_block_idx * 16;
int weight_oc_offset = out_channel_blocks * 16;
#endif
FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_block_idx, 0));
FLOAT4 out1 = out0;
@ -267,48 +239,9 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
int weight_offset = output_channel_block_idx * in_channel_block * 4 * 4;
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
#endif
int input_width_base = in_channel_block_idx * input_shape.y;
int weights_width_base = in_channel_block_idx << 2;
#if (defined USE_LOW_BIT_WEIGHT_INT8)
FLOAT16 weights = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
FLOAT4 weights0 = CONVERT_FLOAT4(weights.s0123) * scale0 + offset0;
FLOAT4 weights1 = CONVERT_FLOAT4(weights.s4567) * scale0 + offset0;
FLOAT4 weights2 = CONVERT_FLOAT4(weights.s89ab) * scale0 + offset0;
FLOAT4 weights3 = CONVERT_FLOAT4(weights.scdef) * scale0 + offset0;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar8 charWeightsInt4 = vload8(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset);
char4 charWeights0 = (char4)(0, 0, 0, 0);
char4 charWeights1 = (char4)(0, 0, 0, 0);
char4 charWeights2 = (char4)(0, 0, 0, 0);
char4 charWeights3 = (char4)(0, 0, 0, 0);
charWeights0.x = (charWeightsInt4.s0 >> 4) - 8;
charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8;
charWeights0.z = (charWeightsInt4.s1 >> 4) - 8;
charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8;
charWeights1.x = (charWeightsInt4.s2 >> 4) - 8;
charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8;
charWeights1.z = (charWeightsInt4.s3 >> 4) - 8;
charWeights1.w = (charWeightsInt4.s3 & MOD_NUM)- 8;
charWeights2.x = (charWeightsInt4.s4 >> 4) - 8;
charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8;
charWeights2.z = (charWeightsInt4.s5 >> 4) - 8;
charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8;
charWeights3.x = (charWeightsInt4.s6 >> 4) - 8;
charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8;
charWeights3.z = (charWeightsInt4.s7 >> 4) - 8;
charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0);
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(weights_width_base, weights + weight_offset);
weights1 = vload4(weights_width_base + 1, weights + weight_offset);
weights2 = vload4(weights_width_base + 2, weights + weight_offset);
@ -319,7 +252,6 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_block_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_block_idx));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx));
in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx));
in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx));
@ -368,17 +300,11 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
@ -390,10 +316,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
__private const int2 stride_shape,
__private const int output_width_4,
__private const int out_channel_blocks
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int inChannel
#endif
) {
const int output_channel_width_idx = get_global_id(0);
@ -404,13 +326,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
const int output_width_block_idx = output_channel_width_idx % output_width_4;
const int output_channel_idx = output_channel_block_idx << 1;
#if (defined USE_LOW_BIT_WEIGHT_INT4)
int weight_ic_offset = output_channel_block_idx * 16;
int weight_oc_offset = out_channel_blocks * 8;
#else
int weight_ic_offset = output_channel_block_idx * 32;
int weight_oc_offset = out_channel_blocks * 16;
#endif
FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_idx, 0));
FLOAT4 out1 = out0;
FLOAT4 out2 = out0;
@ -457,16 +374,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
int weight_offset1 = weight_offset + in_channel_block * 4 * 4;
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
// already pack to 16, no need boundry protect
COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx + 1, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6);
COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7);
#endif
int input_width_base = in_channel_block_idx * input_shape.y;
int weights_width_base = in_channel_block_idx << 2;
@ -475,72 +382,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx));
in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
FLOAT16 weightsInt80 = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
#ifdef CHANNEL_BOUNDARY_PROTECT
FLOAT16 weightsInt81 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
#else
FLOAT16 weightsInt81 = CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset));
#endif
FLOAT4 weights0 = CONVERT_FLOAT4(weightsInt80.s0123) * scale0 + offset0;
FLOAT4 weights1 = CONVERT_FLOAT4(weightsInt80.s4567) * scale0 + offset0;
FLOAT4 weights2 = CONVERT_FLOAT4(weightsInt80.s89ab) * scale0 + offset0;
FLOAT4 weights3 = CONVERT_FLOAT4(weightsInt80.scdef) * scale0 + offset0;
FLOAT4 weights4 = CONVERT_FLOAT4(weightsInt81.s0123) * scale1 + offset1;
FLOAT4 weights5 = CONVERT_FLOAT4(weightsInt81.s4567) * scale1 + offset1;
FLOAT4 weights6 = CONVERT_FLOAT4(weightsInt81.s89ab) * scale1 + offset1;
FLOAT4 weights7 = CONVERT_FLOAT4(weightsInt81.scdef) * scale1 + offset1;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar16 charWeightsInt4 = vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset);
char4 charWeights0 = (char4)(0, 0, 0, 0);
char4 charWeights1 = (char4)(0, 0, 0, 0);
char4 charWeights2 = (char4)(0, 0, 0, 0);
char4 charWeights3 = (char4)(0, 0, 0, 0);
char4 charWeights4 = (char4)(0, 0, 0, 0);
char4 charWeights5 = (char4)(0, 0, 0, 0);
char4 charWeights6 = (char4)(0, 0, 0, 0);
char4 charWeights7 = (char4)(0, 0, 0, 0);
charWeights0.x = (charWeightsInt4.s0 >> 4) - 8;
charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8;
charWeights0.z = (charWeightsInt4.s1 >> 4) - 8;
charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8;
charWeights1.x = (charWeightsInt4.s2 >> 4) - 8;
charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8;
charWeights1.z = (charWeightsInt4.s3 >> 4) - 8;
charWeights1.w = (charWeightsInt4.s3 & MOD_NUM) - 8;
charWeights2.x = (charWeightsInt4.s4 >> 4) - 8;
charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8;
charWeights2.z = (charWeightsInt4.s5 >> 4) - 8;
charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8;
charWeights3.x = (charWeightsInt4.s6 >> 4) - 8;
charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8;
charWeights3.z = (charWeightsInt4.s7 >> 4) - 8;
charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8;
charWeights4.x = (charWeightsInt4.s8 >> 4) - 8;
charWeights4.y = (charWeightsInt4.s8 & MOD_NUM) - 8;
charWeights4.z = (charWeightsInt4.s9 >> 4) - 8;
charWeights4.w = (charWeightsInt4.s9 & MOD_NUM) - 8;
charWeights5.x = (charWeightsInt4.sa >> 4) - 8;
charWeights5.y = (charWeightsInt4.sa & MOD_NUM) - 8;
charWeights5.z = (charWeightsInt4.sb >> 4) - 8;
charWeights5.w = (charWeightsInt4.sb & MOD_NUM) - 8;
charWeights6.x = (charWeightsInt4.sc >> 4) - 8;
charWeights6.y = (charWeightsInt4.sc & MOD_NUM) - 8;
charWeights6.z = (charWeightsInt4.sd >> 4) - 8;
charWeights6.w = (charWeightsInt4.sd & MOD_NUM) - 8;
charWeights7.x = (charWeightsInt4.se >> 4) - 8;
charWeights7.y = (charWeightsInt4.se & MOD_NUM) - 8;
charWeights7.z = (charWeightsInt4.sf >> 4) - 8;
charWeights7.w = (charWeightsInt4.sf & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0);
weights4 = mad(CONVERT_FLOAT4(charWeights4), scale1, offset1);
weights5 = mad(CONVERT_FLOAT4(charWeights5), scale1, offset1);
weights6 = mad(CONVERT_FLOAT4(charWeights6), scale1, offset1);
weights7 = mad(CONVERT_FLOAT4(charWeights7), scale1, offset1);
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(weights_width_base, weights + weight_offset);
weights1 = vload4(weights_width_base + 1, weights + weight_offset);
weights2 = vload4(weights_width_base + 2, weights + weight_offset);
@ -568,8 +410,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights6 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx + 1));
weights7 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx + 1));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
@ -646,17 +486,11 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
@ -675,10 +509,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
__private const int out_width_blocks,
__private const int out_channel_blocks,
__private const int out_height_blocks
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int inChannel
#endif
) {
const int output_channel_width_idx = get_global_id(0);
@ -720,23 +550,15 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), weights_shape.y);
#endif
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4;
#endif
FLOAT4 in0, in1, in2, in3;
FLOAT4 weights0, weights1, weights2, weights3;
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
#endif
const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + kh_start)*weights_shape.y + 0) * 4;
#else
int weights_x_idx = in_channel_block_idx << 2;
@ -751,47 +573,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
READ_INPUT_IMAGE(2, 0);
READ_INPUT_IMAGE(3, 0);
#if (defined USE_LOW_BIT_WEIGHT_INT8)
char4 charWeight0 = vload4(0, kernel_ptr+weight_offset);
char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2);
char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3);
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2);
uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2);
uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2);
uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2);
char4 charWeight0 = (char4)(0, 0, 0, 0);
char4 charWeight1 = (char4)(0, 0, 0, 0);
char4 charWeight2 = (char4)(0, 0, 0, 0);
char4 charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
@ -803,7 +585,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
CALCULATE_OUTPUT(2);
@ -814,47 +595,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
in1 = in2;
in2 = in3;
READ_INPUT_IMAGE(3, w);
#if (defined USE_LOW_BIT_WEIGHT_INT8)
char4 charWeight0 = vload4(0, kernel_ptr+weight_offset);
char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2);
char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3);
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2);
uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2);
uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2);
uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2);
char4 charWeight0 = (char4)(0, 0, 0, 0);
char4 charWeight1 = (char4)(0, 0, 0, 0);
char4 charWeight2 = (char4)(0, 0, 0, 0);
char4 charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
@ -866,7 +607,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
CALCULATE_OUTPUT(2);
@ -879,47 +619,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
READ_INPUT_IMAGE(1, input_width_base);
READ_INPUT_IMAGE(2, input_width_base);
READ_INPUT_IMAGE(3, input_width_base);
#if (defined USE_LOW_BIT_WEIGHT_INT8)
char4 charWeight0 = vload4(0, kernel_ptr+weight_offset);
char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2);
char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3);
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2);
uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2);
uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2);
uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2);
char4 charWeight0 = (char4)(0, 0, 0, 0);
char4 charWeight1 = (char4)(0, 0, 0, 0);
char4 charWeight2 = (char4)(0, 0, 0, 0);
char4 charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
@ -931,7 +631,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
CALCULATE_OUTPUT(2);
@ -978,17 +677,11 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
@ -1007,10 +700,6 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
__private const int out_width_blocks,
__private const int out_channel_blocks,
__private const int out_height_blocks
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int inChannel
#endif
) {
const int output_channel_width_idx = get_global_id(0);
@ -1036,7 +725,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
FLOAT4 out6 = out4;
FLOAT4 out7 = out4;
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
const int weight_oc_offset = weights_shape.x * weights_shape.y * 4;
const int weight_ic_offset = out_channel_blocks * weight_oc_offset;
#endif
@ -1054,18 +743,8 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
FLOAT4 in0, in1, in2, in3;
FLOAT4 weights0, weights1, weights2, weights3, weights4, weights5, weights6, weights7;
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx + 1, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6);
COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7);
#endif
const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4;
#else
int weights_x_idx = in_channel_block_idx << 2;
@ -1083,92 +762,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
in1 = RI_F(input, SAMPLER, (int2)(w0, h1));
in2 = RI_F(input, SAMPLER, (int2)(w0, h2));
in3 = RI_F(input, SAMPLER, (int2)(w0, h3));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
char4 charWeight0 = vload4(0, kernel_ptr+weight_offset);
char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset);
char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*2);
char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*3);
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
#ifdef CHANNEL_BOUNDARY_PROTECT
charWeight0 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
charWeight1 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);
charWeight2 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);
charWeight3 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);
#else
charWeight0 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);
charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);
charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);
#endif
weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1);
weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1);
weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1);
weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1);
weight_offset += 4;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2);
uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset/2);
uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*2/2);
uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*3/2);
char4 charWeight0 = (char4)(0, 0, 0, 0);
char4 charWeight1 = (char4)(0, 0, 0, 0);
char4 charWeight2 = (char4)(0, 0, 0, 0);
char4 charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2);
charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);
charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);
charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);
charWeight0 = (char4)(0, 0, 0, 0);
charWeight1 = (char4)(0, 0, 0, 0);
charWeight2 = (char4)(0, 0, 0, 0);
charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM)- 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1);
weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1);
weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1);
weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1);
weight_offset += 4;
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_ic_offset);
weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2);
@ -1195,8 +789,6 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights6 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weight_size + weights_y_idx));
weights7 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weight_size + weights_y_idx++));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
@ -1279,17 +871,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *kernel_ptr,
__global const float *dequantScaleOffset,
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
@ -1308,10 +894,6 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
__private const int out_width_blocks,
__private const int out_channel_blocks,
__private const int out_height_blocks
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int inChannel
#endif
) {
const int output_channel_width_idx = get_global_id(0);
@ -1344,18 +926,12 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
FLOAT4 in0, in1, in2, in3;
FLOAT4 weights0, weights1, weights2, weights3;
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4;
#endif
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8;
COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex));
COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6);
COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7);
#endif
const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)
#ifdef USE_BUFFER
int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4;
#else
int weights_x_idx = in_channel_block_idx << 2;
@ -1374,47 +950,7 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
in2 = RI_F(input, SAMPLER, (int2)(w0, h2));
in3 = RI_F(input, SAMPLER, (int2)(w0, h3));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
char4 charWeight0 = vload4(0, kernel_ptr+weight_offset);
char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2);
char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3);
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2);
uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2);
uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2);
uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2);
char4 charWeight0 = (char4)(0, 0, 0, 0);
char4 charWeight1 = (char4)(0, 0, 0, 0);
char4 charWeight2 = (char4)(0, 0, 0, 0);
char4 charWeight3 = (char4)(0, 0, 0, 0);
charWeight0.x = (charWeightInt40.s0 >> 4) - 8;
charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8;
charWeight0.z = (charWeightInt40.s1 >> 4) - 8;
charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8;
charWeight1.x = (charWeightInt41.s0 >> 4) - 8;
charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8;
charWeight1.z = (charWeightInt41.s1 >> 4) - 8;
charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8;
charWeight2.x = (charWeightInt42.s0 >> 4) - 8;
charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8;
charWeight2.z = (charWeightInt42.s1 >> 4) - 8;
charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8;
charWeight3.x = (charWeightInt43.s0 >> 4) - 8;
charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8;
charWeight3.z = (charWeightInt43.s1 >> 4) - 8;
charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8;
weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0);
weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0);
weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0);
weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0);
weight_offset += 4;
#elif (defined USE_BUFFER)
#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
@ -1426,7 +962,6 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));
#endif
PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3);
CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@
__kernel
void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS
__global const FLOAT *input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
__global const char *weight,
#else
__global const uchar *weight,
@ -88,7 +88,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS
const int filter_w_inc = (ix-in_w_idx_start)/dilate_hw.y;
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
char4 charWeight0 = vload4(filter_w_inc, weight+weight_offset);
char4 charWeight1 = vload4(filter_w_inc, weight+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(filter_w_inc, weight+weight_offset+weight_oc_offset*2);
@ -155,7 +155,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS
__kernel
void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS
__global const FLOAT *input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
__global const char *weight,
#else
__global const uchar *weight,
@ -224,7 +224,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS
COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4((in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input+inp_offset_base));
COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4((in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input+inp_offset_base));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
char4 charWeight0 = vload4(0, weight+weight_offset);
char4 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset*2);
@ -303,7 +303,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS
__kernel
void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS
__global const FLOAT *input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
__global const char *weight,
#else
__global const uchar *weight,
@ -380,7 +380,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS
COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4((in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base));
COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4((in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
char4 charWeight0 = vload4(0, weight+weight_offset);
char4 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset);
char4 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset*2);
@ -482,7 +482,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS
__kernel
void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
__global const FLOAT *input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
__global const char *weight,
#else
__global const uchar *weight,
@ -579,7 +579,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4((in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base));
COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4((in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
char4 charWeight0 = vload4(0, weight+weight_offset);
char4 charWeight1 = vload4(0, weight+weight_offset+weight_ic_offset);
char4 charWeight2 = vload4(0, weight+weight_offset+weight_ic_offset*2);
@ -640,7 +640,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS
out3 = mad(in3.z, weight2, out3);
out3 = mad(in3.w, weight3, out3);
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
#ifdef CHANNEL_BOUNDARY_PROTECT
charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset);
charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset);

View File

@ -16,7 +16,7 @@ const char* conv_2d_int_buf =
"__kernel\n"
"void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
@ -77,7 +77,7 @@ const char* conv_2d_int_buf =
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n"
" \n"
" const int filter_w_inc=(ix-in_w_idx_start)/dilate_hw.y;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" char4 charWeight0=vload4(filter_w_inc,weight+weight_offset);\n"
" char4 charWeight1=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*2);\n"
@ -139,7 +139,7 @@ const char* conv_2d_int_buf =
"__kernel\n"
"void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
@ -202,7 +202,7 @@ const char* conv_2d_int_buf =
" COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n"
@ -277,7 +277,7 @@ const char* conv_2d_int_buf =
"__kernel\n"
"void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
@ -345,7 +345,7 @@ const char* conv_2d_int_buf =
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n"
@ -442,7 +442,7 @@ const char* conv_2d_int_buf =
"__kernel\n"
"void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
"#else\n"
" __global const uchar *weight,\n"
@ -531,7 +531,7 @@ const char* conv_2d_int_buf =
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" char4 charWeight0=vload4(0,weight+weight_offset);\n"
" char4 charWeight1=vload4(0,weight+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,weight+weight_offset+weight_ic_offset*2);\n"
@ -591,7 +591,7 @@ const char* conv_2d_int_buf =
" out3=mad(in3.z,weight2,out3);\n"
" out3=mad(in3.w,weight3,out3);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
"#if QUANT_BIT == 8\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" charWeight0=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset);\n"
" charWeight1=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n"

File diff suppressed because it is too large Load Diff

View File

@ -13,15 +13,8 @@ const char* conv_2d =
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define UNIT 4\n"
"#define MOD_NUM 15\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,__read_only image2d_t input,\n"
@ -131,17 +124,11 @@ const char* conv_2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
@ -153,23 +140,14 @@ const char* conv_2d =
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*8;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_block_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
@ -201,48 +179,9 @@ const char* conv_2d =
" FLOAT4 weights3;\n"
" int weight_offset=output_channel_block_idx*in_channel_block*4*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weights.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weights.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weights.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weights.scdef)*scale0+offset0;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
@ -253,7 +192,6 @@ const char* conv_2d =
" weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_block_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_block_idx));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
@ -296,17 +234,11 @@ const char* conv_2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
@ -318,10 +250,6 @@ const char* conv_2d =
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
@ -329,13 +257,8 @@ const char* conv_2d =
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
" const int output_channel_idx=output_channel_block_idx << 1;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*32;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
@ -377,16 +300,6 @@ const char* conv_2d =
" int weight_offset=output_channel_idx*in_channel_block*4*4;\n"
" int weight_offset1=weight_offset+in_channel_block*4*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" // already pack to 16,no need boundry protect\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
"#endif\n"
" \n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
@ -394,72 +307,7 @@ const char* conv_2d =
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weightsInt80=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" FLOAT16 weightsInt81=output_channel_idx+1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #else\n"
" FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #endif\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n"
" FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n"
" FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n"
" FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n"
" FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" char4 charWeights4=(char4)(0,0,0,0);\n"
" char4 charWeights5=(char4)(0,0,0,0);\n"
" char4 charWeights6=(char4)(0,0,0,0);\n"
" char4 charWeights7=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n"
" charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n"
" charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n"
" charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n"
" charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n"
" charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n"
" charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n"
" charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n"
" charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n"
" charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n"
" charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n"
" charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n"
" charWeights7.x=(charWeightsInt4.se >> 4)-8;\n"
" charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n"
" charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n"
" charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
" weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
@ -486,8 +334,6 @@ const char* conv_2d =
" weights6=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx+1));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx+1));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
@ -558,17 +404,11 @@ const char* conv_2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
@ -587,10 +427,6 @@ const char* conv_2d =
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
@ -625,22 +461,14 @@ const char* conv_2d =
" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n"
" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n"
"#endif\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" \n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+kh_start)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
@ -655,47 +483,7 @@ const char* conv_2d =
" READ_INPUT_IMAGE(2,0);\n"
" READ_INPUT_IMAGE(3,0);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
@ -707,7 +495,6 @@ const char* conv_2d =
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
@ -718,47 +505,7 @@ const char* conv_2d =
" in1=in2;\n"
" in2=in3;\n"
" READ_INPUT_IMAGE(3,w);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
@ -770,7 +517,6 @@ const char* conv_2d =
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
@ -783,47 +529,7 @@ const char* conv_2d =
" READ_INPUT_IMAGE(1,input_width_base);\n"
" READ_INPUT_IMAGE(2,input_width_base);\n"
" READ_INPUT_IMAGE(3,input_width_base);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
@ -835,7 +541,6 @@ const char* conv_2d =
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx)); \n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
@ -877,17 +582,11 @@ const char* conv_2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
@ -906,10 +605,6 @@ const char* conv_2d =
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
@ -931,7 +626,7 @@ const char* conv_2d =
" FLOAT4 out5=out4;\n"
" FLOAT4 out6=out4;\n"
" FLOAT4 out7=out4;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n"
" const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n"
"#endif\n"
@ -948,18 +643,8 @@ const char* conv_2d =
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" \n"
"#endif\n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
@ -977,91 +662,7 @@ const char* conv_2d =
" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n"
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" charWeight0=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" \n"
" #else\n"
" charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" #endif\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n"
@ -1088,8 +689,6 @@ const char* conv_2d =
" weights6=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weight_size+weights_y_idx));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weight_size+weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" \n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
@ -1167,17 +766,11 @@ const char* conv_2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
@ -1196,10 +789,6 @@ const char* conv_2d =
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
@ -1228,18 +817,12 @@ const char* conv_2d =
" \n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
@ -1258,47 +841,7 @@ const char* conv_2d =
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
"#ifdef USE_BUFFER\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
@ -1310,7 +853,6 @@ const char* conv_2d =
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"

View File

@ -23,7 +23,7 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter,
@ -130,7 +130,7 @@ void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_
}
__kernel
#if SET_ATTRIBUTE
#ifdef SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter,

View File

@ -10,7 +10,7 @@ const char* depthwise_conv2d =
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n"
@ -102,7 +102,7 @@ const char* depthwise_conv2d =
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"#ifdef SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n"

View File

@ -289,82 +289,24 @@ __kernel void gemmWinogradW2(__read_only image2d_t uInput, __read_only image2d_t
__kernel void gemm_conv(GLOBAL_SIZE_DIM2
__read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *weight,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *weight,
__global const float *dequantScaleOffset,
#else
__global const FLOAT *weight,
#endif
__read_only image2d_t bias,
__write_only image2d_t output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int batch
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int srcChannel
#endif
) {
int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b
UNIFORM_BOUNDRY_CHECK(pos.x, pos.y);
FLOAT4 out = RI_F(bias, SAMPLER, (int2)(pos.x, 0));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
int weight_offset = pos.x * 8;
int weight_oc_offset = dstChannelC4 * 8;
#else
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#endif
for (int k = 0; k < srcChannelC4; ++k) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (k * 4) / blockDim * dstChannelC4 * 8;
COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex));
COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6);
COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7);
#endif
FLOAT4 in = RI_F(input, SAMPLER, (int2)(k, pos.y));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset);
char16 charWeights = 0;
charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8;
charWeights.s1 = (charWeightsInt4.s0 & 15) - 8;
charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8;
charWeights.s3 = (charWeightsInt4.s1 & 15) - 8;
charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8;
charWeights.s5 = (charWeightsInt4.s2 & 15) - 8;
charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8;
charWeights.s7 = (charWeightsInt4.s3 & 15) - 8;
charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8;
charWeights.s9 = (charWeightsInt4.s4 & 15) - 8;
charWeights.sa = (charWeightsInt4.s5 >> 4) - 8;
charWeights.sb = (charWeightsInt4.s5 & 15) - 8;
charWeights.sc = (charWeightsInt4.s6 >> 4) - 8;
charWeights.sd = (charWeightsInt4.s6 & 15) - 8;
charWeights.se = (charWeightsInt4.s7 >> 4) - 8;
charWeights.sf = (charWeightsInt4.s7 & 15) - 8;
FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset;
#else
FLOAT16 weights = vload16(0, weight + weight_offset + k * weight_oc_offset);
#endif
PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef);
out = mad((FLOAT4)in.x, (FLOAT4)weights.s0123, out);
out = mad((FLOAT4)in.y, (FLOAT4)weights.s4567, out);
@ -385,24 +327,12 @@ __kernel void gemm_conv(GLOBAL_SIZE_DIM2
__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2
__read_only image2d_t input,
#if (defined USE_LOW_BIT_WEIGHT_INT8)
__global const char *weight,
__global const float *dequantScaleOffset,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
__global const uchar *weight,
__global const float *dequantScaleOffset,
#else
__global const FLOAT *weight,
#endif
__read_only image2d_t bias,
__write_only image2d_t output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int batch
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
,__private const int blockDim
,__private const int srcChannel
#endif
) {
int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b
UNIFORM_BOUNDRY_CHECK(pos.x, pos.y);
@ -412,58 +342,13 @@ __kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2
FLOAT4 bias0 = RI_F(bias, SAMPLER, (int2)(pos.x, 0));
FLOAT4 out0 = bias0, out1 = bias0;
#if (defined USE_LOW_BIT_WEIGHT_INT8)
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
int weight_offset = pos.x * 8;
int weight_oc_offset = dstChannelC4 * 8;
#else
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#endif
for (int k = 0; k < srcChannelC4; ++k) {
#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)
int kindex = (k * 4) / blockDim * dstChannelC4 * 8;
COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex));
COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6);
COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7);
#endif
FLOAT4 in0 = RI_F(input, SAMPLER, (int2)(k, pos_y));
FLOAT4 in1 = RI_F(input, SAMPLER, (int2)(k, pos_y + 1));
#if (defined USE_LOW_BIT_WEIGHT_INT8)
FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset;
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset);
char16 charWeights = 0;
charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8;
charWeights.s1 = (charWeightsInt4.s0 & 15) - 8;
charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8;
charWeights.s3 = (charWeightsInt4.s1 & 15) - 8;
charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8;
charWeights.s5 = (charWeightsInt4.s2 & 15) - 8;
charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8;
charWeights.s7 = (charWeightsInt4.s3 & 15) - 8;
charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8;
charWeights.s9 = (charWeightsInt4.s4 & 15) - 8;
charWeights.sa = (charWeightsInt4.s5 >> 4) - 8;
charWeights.sb = (charWeightsInt4.s5 & 15) - 8;
charWeights.sc = (charWeightsInt4.s6 >> 4) - 8;
charWeights.sd = (charWeightsInt4.s6 & 15) - 8;
charWeights.se = (charWeightsInt4.s7 >> 4) - 8;
charWeights.sf = (charWeightsInt4.s7 & 15) - 8;
FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset;
#else
FLOAT16 weights = vload16(0, weight + weight_offset + k * weight_oc_offset);
#endif
PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef);
out0 = mad((FLOAT4)in0.x, (FLOAT4)weights.s0123, out0);
out0 = mad((FLOAT4)in0.y, (FLOAT4)weights.s4567, out0);

View File

@ -18,9 +18,9 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2
#ifdef USE_IMAGE
__read_only image2d_t weight,
#else
#if (defined USE_LOW_BIT_WEIGHT_INT8)
#if QUANT_BIT == 8
__global const char *weight,
#elif (defined USE_LOW_BIT_WEIGHT_INT4)
#else
__global const uchar *weight,
#endif
#endif
@ -37,7 +37,7 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2
UNIFORM_BOUNDRY_CHECK(x, y);
#if (defined USE_LOW_BIT_WEIGHT_INT4)
#if QUANT_BIT == 4
const int ic = x << 2;
const int oc = y << 3;
const int output_offset = ic * outputChannelAlign + oc;
@ -90,7 +90,7 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2
vstore8(CONVERT_FLOAT8(weights1), 0, output+output_offset+outputChannelAlign);
vstore8(CONVERT_FLOAT8(weights2), 0, output+output_offset+2*outputChannelAlign);
vstore8(CONVERT_FLOAT8(weights3), 0, output+output_offset+3*outputChannelAlign);
#else
#elif QUANT_BIT == 8
const int ic = x << 1;
const int oc = y << 3;
const int output_offset = ic * outputChannelAlign + oc;
@ -125,6 +125,83 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2
#endif
}
// Re-layout a C4NHW4 (channel-quad packed) tensor into NHWC with the channel
// dimension padded up to channelAlign. Each work-item moves one 4x4
// (batch-row x channel) tile; lanes past `channel` and rows past `bhw` are
// written as zero padding.
__kernel void gemm_c4nhw4_to_nhwc(GLOBAL_SIZE_DIM2
__global const FLOAT* input,
__global FLOAT* output,
__private const int bhw,
__private const int channel,
__private const int channelAlign
){
    const int gidB = get_global_id(0); // bhw / 4
    const int gidC = get_global_id(1); // channel / 4
    UNIFORM_BOUNDRY_CHECK(gidB, gidC);
    const int bBase = gidB << 2;
    const int cBase = gidC << 2;
    const int srcBase = gidC * (bhw << 2) + (bBase << 2);
    const int cRemain = channel - cBase;
    // Rows start zeroed so out-of-range batch rows need no special cases.
    FLOAT4 row0 = 0, row1 = 0, row2 = 0, row3 = 0;
    if (cRemain >= 4) {
        // Full channel quad: vectorized loads, guarded per batch row.
        row0 = vload4(0, input + srcBase);
        if (bBase + 1 < bhw) row1 = vload4(0, input + srcBase + 4);
        if (bBase + 2 < bhw) row2 = vload4(0, input + srcBase + 8);
        if (bBase + 3 < bhw) row3 = vload4(0, input + srcBase + 12);
    } else if (cRemain == 1) {
        row0 = (FLOAT4)(input[srcBase], 0, 0, 0);
        if (bBase + 1 < bhw) row1 = (FLOAT4)(input[srcBase + 4], 0, 0, 0);
        if (bBase + 2 < bhw) row2 = (FLOAT4)(input[srcBase + 8], 0, 0, 0);
        if (bBase + 3 < bhw) row3 = (FLOAT4)(input[srcBase + 12], 0, 0, 0);
    } else if (cRemain == 2) {
        row0 = (FLOAT4)(input[srcBase], input[srcBase + 1], 0, 0);
        if (bBase + 1 < bhw) row1 = (FLOAT4)(input[srcBase + 4], input[srcBase + 5], 0, 0);
        if (bBase + 2 < bhw) row2 = (FLOAT4)(input[srcBase + 8], input[srcBase + 9], 0, 0);
        if (bBase + 3 < bhw) row3 = (FLOAT4)(input[srcBase + 12], input[srcBase + 13], 0, 0);
    } else if (cRemain == 3) {
        row0 = (FLOAT4)(input[srcBase], input[srcBase + 1], input[srcBase + 2], 0);
        if (bBase + 1 < bhw) row1 = (FLOAT4)(input[srcBase + 4], input[srcBase + 5], input[srcBase + 6], 0);
        if (bBase + 2 < bhw) row2 = (FLOAT4)(input[srcBase + 8], input[srcBase + 9], input[srcBase + 10], 0);
        if (bBase + 3 < bhw) row3 = (FLOAT4)(input[srcBase + 12], input[srcBase + 13], input[srcBase + 14], 0);
    }
    // Destination is row-major over (bhw, channelAlign).
    const int dstBase = bBase * channelAlign + cBase;
    vstore4(row0, 0, output + dstBase);
    vstore4(row1, 0, output + dstBase + channelAlign);
    vstore4(row2, 0, output + dstBase + 2 * channelAlign);
    vstore4(row3, 0, output + dstBase + 3 * channelAlign);
}
// Inverse of gemm_c4nhw4_to_nhwc: gather a 4x4 (batch-row x channel) tile
// from the channelAlign-padded NHWC buffer and scatter it back into C4NHW4
// order. NOTE(review): all four rows are loaded unconditionally — this
// assumes the NHWC source buffer is padded to a multiple of 4 rows; only the
// stores are bounds-checked against bhw. Confirm against the allocator.
__kernel void gemm_nhwc_to_c4nhw4(GLOBAL_SIZE_DIM2
__global const FLOAT* input,
__global FLOAT* output,
__private const int bhw,
__private const int channelAlign
){
    const int gidB = get_global_id(0); // bhw / 4
    const int gidC = get_global_id(1); // c / 4
    UNIFORM_BOUNDRY_CHECK(gidB, gidC);
    const int bBase = gidB << 2;
    const int cBase = gidC << 2;
    const int srcBase = bBase * channelAlign + cBase;
    const FLOAT4 row0 = vload4(0, input + srcBase);
    const FLOAT4 row1 = vload4(0, input + srcBase + channelAlign);
    const FLOAT4 row2 = vload4(0, input + srcBase + 2 * channelAlign);
    const FLOAT4 row3 = vload4(0, input + srcBase + 3 * channelAlign);
    const int dstBase = gidC * (bhw << 2) + (bBase << 2);
    vstore4(row0, 0, output + dstBase);
    if (bBase + 1 < bhw) {
        vstore4(row1, 0, output + dstBase + 4);
        if (bBase + 2 < bhw) {
            vstore4(row2, 0, output + dstBase + 8);
            if (bBase + 3 < bhw) {
                vstore4(row3, 0, output + dstBase + 12);
            }
        }
    }
}
#define UCHAR4_TO_FLOAT8(b, scale, offset) \
wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); \
wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); \
@ -440,7 +517,6 @@ __kernel void gemm_b4_c8_int4_buf(GLOBAL_SIZE_DIM2
#endif
}
__kernel void gemm_b4_c8_int8_buf(GLOBAL_SIZE_DIM2
__global const FLOAT* input,
#ifdef USE_IMAGE

View File

@ -13,9 +13,9 @@ const char* gemm_conv1x1_buf =
" #ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
" #else\n"
" #if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" #if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" #elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" #else\n"
" __global const uchar *weight,\n"
" #endif\n"
" #endif\n"
@ -31,7 +31,7 @@ const char* gemm_conv1x1_buf =
" const int y=get_global_id(1); //oc\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
"#if QUANT_BIT == 4\n"
" const int ic=x << 2;\n"
" const int oc=y << 3;\n"
" const int output_offset=ic*outputChannelAlign+oc;\n"
@ -83,7 +83,7 @@ const char* gemm_conv1x1_buf =
" vstore8(CONVERT_FLOAT8(weights1),0,output+output_offset+outputChannelAlign);\n"
" vstore8(CONVERT_FLOAT8(weights2),0,output+output_offset+2*outputChannelAlign);\n"
" vstore8(CONVERT_FLOAT8(weights3),0,output+output_offset+3*outputChannelAlign);\n"
"#else\n"
"#elif QUANT_BIT == 8\n"
" const int ic=x << 1;\n"
" const int oc=y << 3;\n"
" const int output_offset=ic*outputChannelAlign+oc;\n"
@ -117,6 +117,81 @@ const char* gemm_conv1x1_buf =
" vstore8(CONVERT_FLOAT8(weights1),0,output+output_offset+outputChannelAlign);\n"
" #endif\n"
"}\n"
"__kernel void gemm_c4nhw4_to_nhwc(GLOBAL_SIZE_DIM2\n"
"__global const FLOAT* input,\n"
"__global FLOAT* output,\n"
"__private const int bhw,\n"
"__private const int channel,\n"
"__private const int channelAlign\n"
"){\n"
" const int x=get_global_id(0); //b/4\n"
" const int y=get_global_id(1); //c/4\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" const int out_b_idx=x << 2;\n"
" const int out_c_idx=y << 2;\n"
" const int bhw4=bhw << 2;\n"
" const int input_offset=y*bhw4+out_b_idx*4;\n"
" FLOAT4 in0,in1,in2,in3;\n"
" if(out_c_idx+3<channel && out_b_idx+3<bhw){\n"
" in0=vload4(0,input+input_offset);\n"
" in1=vload4(0,input+input_offset+4);\n"
" in2=vload4(0,input+input_offset+8);\n"
" in3=vload4(0,input+input_offset+12);\n"
" } else{\n"
" if(out_c_idx+3<channel){\n"
" in0=vload4(0,input+input_offset);\n"
" in1=out_b_idx+1<bhw ? vload4(0,input+input_offset+4) : 0;\n"
" in2=out_b_idx+2<bhw ? vload4(0,input+input_offset+8) : 0;\n"
" in3=out_b_idx+3<bhw ? vload4(0,input+input_offset+12) : 0;\n"
" } else if(out_c_idx+1 == channel){\n"
" in0=(FLOAT4)(input[input_offset],0,0,0);\n"
" in1=out_b_idx+1<bhw ? (FLOAT4)(input[input_offset+4],0,0,0) : 0;\n"
" in2=out_b_idx+2<bhw ? (FLOAT4)(input[input_offset+8],0,0,0) : 0;\n"
" in3=out_b_idx+3<bhw ? (FLOAT4)(input[input_offset+12],0,0,0) : 0;\n"
" } else if(out_c_idx+2 == channel){\n"
" in0=(FLOAT4)(input[input_offset],input[input_offset+1],0,0);\n"
" in1=out_b_idx+1<bhw ? (FLOAT4)(input[input_offset+4],input[input_offset+5],0,0) : 0;\n"
" in2=out_b_idx+2<bhw ? (FLOAT4)(input[input_offset+8],input[input_offset+9],0,0) : 0;\n"
" in3=out_b_idx+3<bhw ? (FLOAT4)(input[input_offset+12],input[input_offset+13],0,0) : 0;\n"
" } else if(out_c_idx+3 == channel){\n"
" in0=(FLOAT4)(input[input_offset],input[input_offset+1],input[input_offset+2],0);\n"
" in1=out_b_idx+1<bhw ? (FLOAT4)(input[input_offset+4],input[input_offset+5],input[input_offset+6],0) : 0;\n"
" in2=out_b_idx+2<bhw ? (FLOAT4)(input[input_offset+8],input[input_offset+9],input[input_offset+10],0) : 0;\n"
" in3=out_b_idx+3<bhw ? (FLOAT4)(input[input_offset+12],input[input_offset+13],input[input_offset+14],0) : 0;\n"
" }\n"
" }\n"
" int out_offset=out_b_idx*channelAlign+out_c_idx;\n"
" vstore4(in0,0,output+out_offset);\n"
" vstore4(in1,0,output+out_offset+channelAlign);\n"
" vstore4(in2,0,output+out_offset+channelAlign+channelAlign);\n"
" vstore4(in3,0,output+out_offset+channelAlign+channelAlign+channelAlign);\n"
"}\n"
"__kernel void gemm_nhwc_to_c4nhw4(GLOBAL_SIZE_DIM2\n"
"__global const FLOAT* input,\n"
"__global FLOAT* output,\n"
"__private const int bhw,\n"
"__private const int channelAlign\n"
"){\n"
" const int x=get_global_id(0); //b/4\n"
" const int y=get_global_id(1); //c/4\n"
" UNIFORM_BOUNDRY_CHECK(x,y);\n"
" const int out_b_idx=x << 2;\n"
" const int out_c_idx=y << 2;\n"
" const int bhw4=bhw << 2;\n"
" const int input_offset=out_b_idx*channelAlign+out_c_idx;\n"
" FLOAT4 in0=vload4(0,input+input_offset);\n"
" FLOAT4 in1=vload4(0,input+input_offset+channelAlign);\n"
" FLOAT4 in2=vload4(0,input+input_offset+channelAlign+channelAlign);\n"
" FLOAT4 in3=vload4(0,input+input_offset+channelAlign+channelAlign+channelAlign);\n"
" int out_offset=y*bhw4+out_b_idx*4;\n"
" vstore4(in0,0,output+out_offset);\n"
" if(out_b_idx+1 >= bhw) return;\n"
" vstore4(in1,0,output+out_offset+4);\n"
" if(out_b_idx+2 >= bhw) return;\n"
" vstore4(in2,0,output+out_offset+8);\n"
" if(out_b_idx+3 >= bhw) return;\n"
" vstore4(in3,0,output+out_offset+12);\n"
"}\n"
"#define UCHAR4_TO_FLOAT8(b, scale, offset) "" wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); "" wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); "" wei.s2 = (COMPUTE_FLOAT)((b.s1 >> 4) - 8); "" wei.s3 = (COMPUTE_FLOAT)((b.s1 & 15) - 8); "" wei.s4 = (COMPUTE_FLOAT)((b.s2 >> 4) - 8); "" wei.s5 = (COMPUTE_FLOAT)((b.s2 & 15) - 8); "" wei.s6 = (COMPUTE_FLOAT)((b.s3 >> 4) - 8); "" wei.s7 = (COMPUTE_FLOAT)((b.s3 & 15) - 8); "" wei=wei*scale+offset;\n"
"__kernel void gemm_b4_c8_int4_buf(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* input,\n"

View File

@ -0,0 +1,203 @@
#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
#define GLOBAL_SIZE_DIM2 \
__private int global_size_dim0, __private int global_size_dim1,
#define UNIFORM_BOUNDRY_CHECK(index0, index1) \
if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \
return; \
}
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#ifdef INPUT_CHANNEL_LEAVE
#define PADZEROSVEC(k, channel, data0, data1, data2, data3) \
data0 = (k << 2) < channel ? data0 : 0; \
data1 = (k << 2) + 1 < channel ? data1 : 0; \
data2 = (k << 2) + 2 < channel ? data2 : 0; \
data3 = (k << 2) + 3 < channel ? data3 : 0;
#else
#define PADZEROSVEC(k, channel, data0, data1, data2, data3)
#endif
// Quantized (QUANT_BIT == 8 or 4) 1x1-convolution GEMM over an image2d input.
// One work-item computes one FLOAT4 of output channels (pos.x = cout/4) for
// one batch row (pos.y). Weights are dequantized on the fly as
// w * scale + offset, with scale/offset blocks covering blockDim input
// channels each.
__kernel void gemm_conv(GLOBAL_SIZE_DIM2
__read_only image2d_t input,
#if QUANT_BIT == 8
__global const char *weight,
__global const float *dequantScaleOffset,
#else
// QUANT_BIT == 4: two 4-bit weights packed per byte, high nibble first.
__global const uchar *weight,
__global const float *dequantScaleOffset,
#endif
__read_only image2d_t bias,
__write_only image2d_t output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int batch
,__private const int blockDim
,__private const int srcChannel
) {
int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b
UNIFORM_BOUNDRY_CHECK(pos.x, pos.y);
// Accumulator starts at the bias for this output-channel quad.
FLOAT4 out = RI_F(bias, SAMPLER, (int2)(pos.x, 0));
// Bytes of packed weight consumed per channel quad: 16 int8 values, or
// 16 int4 values packed into 8 bytes.
#if QUANT_BIT == 8
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#else
int weight_offset = pos.x * 8;
int weight_oc_offset = dstChannelC4 * 8;
#endif
for (int k = 0; k < srcChannelC4; ++k) {
// dequantScaleOffset holds interleaved (scale, offset) pairs: 4 output
// channels per quad, one pair set per blockDim input channels.
int kindex = (k * 4) / blockDim * dstChannelC4 * 8;
COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex));
// Broadcast the 4 per-output-channel scales/offsets across all 16 lanes
// (4 input channels x 4 output channels).
COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6);
COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7);
FLOAT4 in = RI_F(input, SAMPLER, (int2)(k, pos.y));
#if QUANT_BIT == 8
FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset;
#else
// Unpack 16 int4 weights from 8 bytes (high nibble first), re-centering
// the unsigned nibble range [0,15] to signed [-8,7].
uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset);
char16 charWeights = 0;
charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8;
charWeights.s1 = (charWeightsInt4.s0 & 15) - 8;
charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8;
charWeights.s3 = (charWeightsInt4.s1 & 15) - 8;
charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8;
charWeights.s5 = (charWeightsInt4.s2 & 15) - 8;
charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8;
charWeights.s7 = (charWeightsInt4.s3 & 15) - 8;
charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8;
charWeights.s9 = (charWeightsInt4.s4 & 15) - 8;
charWeights.sa = (charWeightsInt4.s5 >> 4) - 8;
charWeights.sb = (charWeightsInt4.s5 & 15) - 8;
charWeights.sc = (charWeightsInt4.s6 >> 4) - 8;
charWeights.sd = (charWeightsInt4.s6 & 15) - 8;
charWeights.se = (charWeightsInt4.s7 >> 4) - 8;
charWeights.sf = (charWeightsInt4.s7 & 15) - 8;
FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset;
#endif
// Zero lanes past srcChannel when channels are not a multiple of 4
// (no-op unless INPUT_CHANNEL_LEAVE is defined).
PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef);
out = mad((FLOAT4)in.x, (FLOAT4)weights.s0123, out);
out = mad((FLOAT4)in.y, (FLOAT4)weights.s4567, out);
out = mad((FLOAT4)in.z, (FLOAT4)weights.s89ab, out);
out = mad((FLOAT4)in.w, (FLOAT4)weights.scdef, out);
}
// Optional fused activations, selected at compile time.
#ifdef RELU
out = fmax(out, (FLOAT4)0);
#endif
#ifdef RELU6
out = clamp(out, (FLOAT4)0, (FLOAT4)6);
#endif
WI_F(output, (int2)(pos.x, pos.y), out);
}
// Variant of gemm_conv that computes two batch rows (pos_y and pos_y + 1)
// per work-item, loading/dequantizing each weight quad once and reusing it
// for both rows to improve arithmetic intensity. The trailing odd row is
// guarded by the `batch` bound on the final store.
// Fix: removed the unused local `pos_x` (computed but never read).
__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2
__read_only image2d_t input,
#if QUANT_BIT == 8
__global const char *weight,
__global const float *dequantScaleOffset,
#else
// QUANT_BIT == 4: two 4-bit weights packed per byte, high nibble first.
__global const uchar *weight,
__global const float *dequantScaleOffset,
#endif
__read_only image2d_t bias,
__write_only image2d_t output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int batch
,__private const int blockDim
,__private const int srcChannel
) {
int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b
UNIFORM_BOUNDRY_CHECK(pos.x, pos.y);
// First of the two batch rows handled by this work-item.
int pos_y = pos.y << 1;
FLOAT4 bias0 = RI_F(bias, SAMPLER, (int2)(pos.x, 0));
FLOAT4 out0 = bias0, out1 = bias0;
// Bytes of packed weight per channel quad: 16 for int8, 8 for packed int4.
#if QUANT_BIT == 8
int weight_offset = pos.x * 16;
int weight_oc_offset = dstChannelC4 * 16;
#else
int weight_offset = pos.x * 8;
int weight_oc_offset = dstChannelC4 * 8;
#endif
for (int k = 0; k < srcChannelC4; ++k) {
// Per-block dequant params: interleaved (scale, offset) pairs for the
// 4 output channels, one set per blockDim input channels.
int kindex = (k * 4) / blockDim * dstChannelC4 * 8;
COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex));
COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6,
ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6);
COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7,
ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7);
FLOAT4 in0 = RI_F(input, SAMPLER, (int2)(k, pos_y));
FLOAT4 in1 = RI_F(input, SAMPLER, (int2)(k, pos_y + 1));
#if QUANT_BIT == 8
FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset;
#else
// Unpack 16 int4 weights from 8 bytes, re-centering [0,15] to [-8,7].
uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset);
char16 charWeights = 0;
charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8;
charWeights.s1 = (charWeightsInt4.s0 & 15) - 8;
charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8;
charWeights.s3 = (charWeightsInt4.s1 & 15) - 8;
charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8;
charWeights.s5 = (charWeightsInt4.s2 & 15) - 8;
charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8;
charWeights.s7 = (charWeightsInt4.s3 & 15) - 8;
charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8;
charWeights.s9 = (charWeightsInt4.s4 & 15) - 8;
charWeights.sa = (charWeightsInt4.s5 >> 4) - 8;
charWeights.sb = (charWeightsInt4.s5 & 15) - 8;
charWeights.sc = (charWeightsInt4.s6 >> 4) - 8;
charWeights.sd = (charWeightsInt4.s6 & 15) - 8;
charWeights.se = (charWeightsInt4.s7 >> 4) - 8;
charWeights.sf = (charWeightsInt4.s7 & 15) - 8;
FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset;
#endif
// Zero lanes past srcChannel (no-op unless INPUT_CHANNEL_LEAVE is set).
PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef);
out0 = mad((FLOAT4)in0.x, (FLOAT4)weights.s0123, out0);
out0 = mad((FLOAT4)in0.y, (FLOAT4)weights.s4567, out0);
out0 = mad((FLOAT4)in0.z, (FLOAT4)weights.s89ab, out0);
out0 = mad((FLOAT4)in0.w, (FLOAT4)weights.scdef, out0);
out1 = mad((FLOAT4)in1.x, (FLOAT4)weights.s0123, out1);
out1 = mad((FLOAT4)in1.y, (FLOAT4)weights.s4567, out1);
out1 = mad((FLOAT4)in1.z, (FLOAT4)weights.s89ab, out1);
out1 = mad((FLOAT4)in1.w, (FLOAT4)weights.scdef, out1);
}
// Optional fused activations, selected at compile time.
#ifdef RELU
out0 = fmax(out0, (FLOAT4)0);
out1 = fmax(out1, (FLOAT4)0);
#endif
#ifdef RELU6
out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
#endif
WI_F(output, (int2)(pos.x, pos_y), out0);
// Guard the second row: the last work-item may fall past an odd batch size.
if(pos_y + 1 < batch)
WI_F(output, (int2)(pos.x, pos_y + 1), out1);
}

View File

@ -0,0 +1,185 @@
#include "opencl_source_map.hpp"
// Embedded OpenCL program source for the quantized GEMM kernels (gemm_conv /
// gemm_conv_b2, int8 or int4 selected via the QUANT_BIT build option).
// NOTE(review): this file appears to be auto-generated from gemm_int.cl —
// edit the .cl source and regenerate rather than hand-editing the strings.
namespace MNN {
const char* gemm_int = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#ifdef INPUT_CHANNEL_LEAVE\n"
"	#define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
"	#define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
"__kernel void gemm_conv(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" FLOAT4 out=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
"#if QUANT_BIT == 8\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#else \n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(k,pos.y));\n"
"#if QUANT_BIT == 8\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#else\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n"
" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n"
" out=mad((FLOAT4)in.z,(FLOAT4)weights.s89ab,out);\n"
" out=mad((FLOAT4)in.w,(FLOAT4)weights.scdef,out);\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos.y),out);\n"
"}\n"
"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int pos_x=pos.x << 2;\n"
" int pos_y=pos.y << 1;\n"
" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" FLOAT4 out0=bias0,out1=bias0;\n"
" \n"
"#if QUANT_BIT == 8\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#else\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" FLOAT4 in0=RI_F(input,SAMPLER,(int2)(k,pos_y));\n"
" FLOAT4 in1=RI_F(input,SAMPLER,(int2)(k,pos_y+1));\n"
"#if QUANT_BIT == 8\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#else\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n"
" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n"
" out0=mad((FLOAT4)in0.z,(FLOAT4)weights.s89ab,out0);\n"
" out0=mad((FLOAT4)in0.w,(FLOAT4)weights.scdef,out0);\n"
" \n"
" out1=mad((FLOAT4)in1.x,(FLOAT4)weights.s0123,out1);\n"
" out1=mad((FLOAT4)in1.y,(FLOAT4)weights.s4567,out1);\n"
" out1=mad((FLOAT4)in1.z,(FLOAT4)weights.s89ab,out1);\n"
" out1=mad((FLOAT4)in1.w,(FLOAT4)weights.scdef,out1);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos_y),out0);\n"
" if(pos_y+1<batch)\n"
" WI_F(output,(int2)(pos.x,pos_y+1),out1);\n"
"}\n"
;
}

View File

@ -248,79 +248,21 @@ const char* gemm =
"#endif\n"
"__kernel void gemm_conv(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const FLOAT *weight,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
"#endif\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" FLOAT4 out=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#else\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
"#endif\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(k,pos.y));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
" \n"
"#else\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n"
" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n"
@ -338,24 +280,12 @@ const char* gemm =
"}\n"
"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const FLOAT *weight,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
"#endif\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
@ -364,57 +294,12 @@ const char* gemm =
" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" FLOAT4 out0=bias0,out1=bias0;\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#else\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
"#endif\n"
" FLOAT4 in0=RI_F(input,SAMPLER,(int2)(k,pos_y));\n"
" FLOAT4 in1=RI_F(input,SAMPLER,(int2)(k,pos_y+1));\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#else\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n"
" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n"

View File

@ -6,6 +6,9 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
#define GLOBAL_SIZE_DIM_2 \
__private int global_size_dim0, __private int global_size_dim1,
#define GLOBAL_SIZE_DIM_3 \
__private int global_size_dim0, __private int global_size_dim1, __private int global_size_dim2,
#define UNIFORM_BOUNDRY_CHECK_2(index0, index1) \
if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \
return; \
@ -24,16 +27,22 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
#if WGS >= 8
__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3
__global const FLOAT* input,
#ifdef USE_IMAGE
__read_only image2d_t weight,
#else
#if QUANT_BIT == 8
__global const char *weight,
#else
__global const uchar *weight,
#endif
#endif
__global const FLOAT *dequantScaleOffset,
__global const FLOAT *bias,
__global FLOAT* output,
__private const int dstChannelAlign,
__private const int srcChannelAlign,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int srcChannel,
@ -43,32 +52,101 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
const int lid = get_local_id(0);
const int oc = get_global_id(1); //oc/8
const int oc8 = oc << 3;
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop = max((srcChannel + 4 - 1) / 4 - 1, 0);
#else
const int loop = (srcChannel + 4 - 1) / 4;
#endif
__local COMPUTE_FLOAT8 sum[WGS];
COMPUTE_FLOAT8 out0 = 0;
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 16;
#endif
#if QUANT_BIT == 8
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop = max((srcChannel + 2 - 1) / 2 - 1, 0);
#else
const int loop = (srcChannel + 2 - 1) / 2;
#endif
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 32;
#endif
#else
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop = max((srcChannel + 4 - 1) / 4 - 1, 0);
#else
const int loop = (srcChannel + 4 - 1) / 4;
#endif
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 16;
#endif
#endif
COMPUTE_FLOAT8 out0 = 0;
int input_offset = 0, output_offset = oc8;
__local COMPUTE_FLOAT8 sum0[WGS];
#ifdef COMPUTE_BATCH
const int out_b_idx = get_global_id(2) << 2; //b/4
__local COMPUTE_FLOAT8 sum1[WGS];
__local COMPUTE_FLOAT8 sum2[WGS];
__local COMPUTE_FLOAT8 sum3[WGS];
COMPUTE_FLOAT8 out1 = 0, out2 = 0, out3 = 0;
input_offset = out_b_idx * srcChannelAlign;
output_offset = oc8 + out_b_idx * dstChannelAlign;
#endif
for(int j = lid; j < loop; j+=WGS){
#if QUANT_BIT == 8
int k2 = j << 1;
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + k2));
#ifdef COMPUTE_BATCH
COMPUTE_FLOAT2 in1 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign + k2));
COMPUTE_FLOAT2 in2 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign * 2 + k2));
COMPUTE_FLOAT2 in3 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign * 3 + k2));
#endif
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(j, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s0, wei.s01234567, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s0, wei.s01234567, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s0, wei.s01234567, out3);
#endif
}
{
out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s1, wei.s89abcdef, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s1, wei.s89abcdef, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s1, wei.s89abcdef, out3);
#endif
}
#else
int k4 = j << 2;
#ifdef ASYMMETRIC
#ifdef ASYMMETRIC
COMPUTE_FLOAT8 scale, offset;
{
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k4 / blockDim) * dstChannelC4 * 8)) / coef);
scale = scaleOffset.s02468ace;
offset = scaleOffset.s13579bdf;
}
#else
#else
COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k4 / blockDim) * dstChannelC4 * 4)) / coef);
COMPUTE_FLOAT8 offset = 0;
#endif
#endif
COMPUTE_FLOAT8 wei;
COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4));
COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4 + input_offset));
#ifdef COMPUTE_BATCH
COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign + k4));
COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign * 2 + k4));
COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign * 3 + k4));
#endif
#ifdef USE_IMAGE
uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(j, oc)));
#else
@ -77,39 +155,83 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
{
UCHAR4_TO_CHAR8(charWeightsInt40.s0123, scale, offset);
out0 = mad((COMPUTE_FLOAT8)in.s0, wei, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s0, wei, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s0, wei, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s0, wei, out3);
#endif
}
{
UCHAR4_TO_CHAR8(charWeightsInt40.s4567, scale, offset);
out0 = mad((COMPUTE_FLOAT8)in.s1, wei, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s1, wei, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s1, wei, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s1, wei, out3);
#endif
}
{
UCHAR4_TO_CHAR8(charWeightsInt40.s89ab, scale, offset);
out0 = mad((COMPUTE_FLOAT8)in.s2, wei, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s2, wei, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s2, wei, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s2, wei, out3);
#endif
}
{
UCHAR4_TO_CHAR8(charWeightsInt40.scdef, scale, offset);
out0 = mad((COMPUTE_FLOAT8)in.s3, wei, out0);
#ifdef COMPUTE_BATCH
out1 = mad((COMPUTE_FLOAT8)in1.s3, wei, out1);
out2 = mad((COMPUTE_FLOAT8)in2.s3, wei, out2);
out3 = mad((COMPUTE_FLOAT8)in3.s3, wei, out3);
#endif
}
#endif
}
#if INPUT_CHANNEL_LEAVES_NUM != 0
{
#if QUANT_BIT == 8
int k2 = loop << 1;
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(loop, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(loop, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)input[k2], wei.s01234567, out0);
}
#else
int k4 = loop << 2;
#ifdef ASYMMETRIC
#ifdef ASYMMETRIC
COMPUTE_FLOAT8 scale, offset;
{
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k4 / blockDim) * dstChannelC4 * 8)) / coef);
scale = scaleOffset.s02468ace;
offset = scaleOffset.s13579bdf;
}
#else
#else
COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k4 / blockDim) * dstChannelC4 * 4)) / coef);
COMPUTE_FLOAT8 offset = 0;
#endif
#endif
COMPUTE_FLOAT8 wei;
#ifdef USE_IMAGE
uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(loop, oc)));
#else
uchar16 charWeightsInt40 = vload16(j, weight + weight_offset);
uchar16 charWeightsInt40 = vload16(loop, weight + weight_offset);
#endif
{
UCHAR4_TO_CHAR8(charWeightsInt40.s0123, scale, offset);
@ -127,17 +249,28 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
out0 = mad((COMPUTE_FLOAT8)input[k4 + 2], wei, out0);
}
#endif
#endif
}
#endif
sum[lid] = out0;
sum0[lid] = out0;
#ifdef COMPUTE_BATCH
sum1[lid] = out1; sum2[lid] = out2; sum3[lid] = out3;
#endif
barrier(CLK_LOCAL_MEM_FENCE);
for(int i = WGS/2; i > 0; i /= 2){
if (lid < i)
sum[lid] = sum[lid] + sum[lid + i];
if (lid < i){
sum0[lid] = sum0[lid] + sum0[lid + i];
#ifdef COMPUTE_BATCH
sum1[lid] = sum1[lid] + sum1[lid + i];
sum2[lid] = sum2[lid] + sum2[lid + i];
sum3[lid] = sum3[lid] + sum3[lid + i];
#endif
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if(lid == 0){
out0 = sum[0] + CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8));
COMPUTE_FLOAT8 vBias = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8));
out0 = sum0[0] + vBias;
#ifdef RELU
out0 = fmax(out0, (COMPUTE_FLOAT8)0);
#endif
@ -146,132 +279,43 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);
#endif
#ifdef OUTPUT_CHANNEL_LEAVES
vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8);
vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + output_offset);
if(oc8 + 4 < dstChannelC4 * 4)
vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4);
vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + 4 + output_offset);
#else
vstore8(CONVERT_FLOAT8(out0), 0, output + oc8);
vstore8(CONVERT_FLOAT8(out0), 0, output + output_offset);
#endif
#ifdef COMPUTE_BATCH
out1 = sum1[0] + vBias; out2 = sum2[0] + vBias; out3 = sum3[0] + vBias;
#ifdef RELU
out1 = fmax(out1, (COMPUTE_FLOAT8)0);out2 = fmax(out2, (COMPUTE_FLOAT8)0);out3 = fmax(out3, (COMPUTE_FLOAT8)0);
#endif
#ifdef RELU6
out1 = clamp(out1, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);out2 = clamp(out2, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);out3 = clamp(out3, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);
#endif
vstore8(CONVERT_FLOAT8(out1), 0, output + output_offset + dstChannelAlign);
vstore8(CONVERT_FLOAT8(out2), 0, output + output_offset + dstChannelAlign + dstChannelAlign);
vstore8(CONVERT_FLOAT8(out3), 0, output + output_offset + dstChannelAlign + dstChannelAlign + dstChannelAlign);
#endif
}
}
__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2
#else
__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3
__global const FLOAT* input,
#ifdef USE_IMAGE
__read_only image2d_t weight,
#else
#if QUANT_BIT == 8
__global const char *weight,
#endif
__global const FLOAT *dequantScaleOffset,
__global const FLOAT *bias,
__global FLOAT* output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int srcChannel,
__private const int blockNum,
__private const int blockDim,
__private const float coef) {
const int lid = get_local_id(0);
const int oc = get_global_id(1); //oc/8
const int oc8 = oc << 3;
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop = max((srcChannel + 2 - 1) / 2 - 1, 0);
#else
const int loop = (srcChannel + 2 - 1) / 2;
#endif
__local COMPUTE_FLOAT8 sum[WGS];
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 32;
#endif
COMPUTE_FLOAT8 out0 = 0;
for(int j = lid; j < loop; j+=WGS){
int k2 = j << 1;
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + k2));
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(j, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0);
}
{
out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0);
}
}
#if INPUT_CHANNEL_LEAVES_NUM != 0
{
int k2 = loop << 1;
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(loop, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)input[k2], wei.s01234567, out0);
}
}
#endif
sum[lid] = out0;
barrier(CLK_LOCAL_MEM_FENCE);
for(int i = WGS/2; i > 0; i /= 2){
if (lid < i)
sum[lid] = sum[lid] + sum[lid + i];
barrier(CLK_LOCAL_MEM_FENCE);
}
if(lid == 0){
out0 = sum[0] + CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8));
#ifdef RELU
out0 = fmax(out0, (COMPUTE_FLOAT8)0);
#endif
#ifdef RELU6
out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);
#endif
#ifdef OUTPUT_CHANNEL_LEAVES
vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8);
if(oc8 + 4 < dstChannelC4 * 4)
vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4);
#else
vstore8(CONVERT_FLOAT8(out0), 0, output + oc8);
#endif
}
}
#else
__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
__global const FLOAT* input,
#ifdef USE_IMAGE
__read_only image2d_t weight,
#else
#else
__global const uchar *weight,
#endif
#endif
__global const FLOAT *dequantScaleOffset,
__global const FLOAT *bias,
__global FLOAT* output,
__private const int dstChannelAlign,
__private const int srcChannelAlign,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int srcChannel,
@ -283,29 +327,83 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
UNIFORM_BOUNDRY_CHECK_2(ic, oc);
const int oc8 = oc << 3;
const int loop = (blockDim + 4 - 1) / 4;
#if INPUT_CHANNEL_LEAVES_NUM != 0
#if QUANT_BIT == 8
const int loop = (blockDim + 2 - 1) / 2;
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop_end = max(loop - 1, 0);
#else
#else
const int loop_end = loop;
#endif
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 32;
#endif
#else
const int loop = (blockDim + 4 - 1) / 4;
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop_end = max(loop - 1, 0);
#else
const int loop_end = loop;
#endif
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 16;
#endif
#endif
COMPUTE_FLOAT8 out0 = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8));
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 16;
#endif
for (int i = 0; i < blockNum; i++){
#ifdef ASYMMETRIC
#if QUANT_BIT == 8
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
for (int j = 0; j < loop_end; j++) {
int k = i * loop + j;
COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + (k << 1)));
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0);
}
{
out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0);
}
}
#if INPUT_CHANNEL_LEAVES_NUM != 0
{
int k = i * loop + loop_end;
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)input[k << 1], wei.s01234567, out0);
}
}
#endif
#else
#ifdef ASYMMETRIC
COMPUTE_FLOAT8 scale, offset;
{
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef);
scale = scaleOffset.s02468ace;
offset = scaleOffset.s13579bdf;
}
#else
#else
COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef);
COMPUTE_FLOAT8 offset = 0;
#endif
#endif
for (int j = 0; j < loop_end; j++) {
int k = i * loop + j;
COMPUTE_FLOAT8 wei;
@ -360,94 +458,7 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2
#endif
}
#endif
}
#ifdef RELU
out0 = fmax(out0, (COMPUTE_FLOAT8)0);
#endif
#ifdef RELU6
out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);
#endif
#ifdef OUTPUT_CHANNEL_LEAVES
vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8);
if(oc8 + 4 < dstChannelC4 * 4)
vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4);
#else
vstore8(CONVERT_FLOAT8(out0), 0, output + oc8);
#endif
}
__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2
__global const FLOAT* input,
#ifdef USE_IMAGE
__read_only image2d_t weight,
#else
__global const char *weight,
#endif
__global const FLOAT *dequantScaleOffset,
__global const FLOAT *bias,
__global FLOAT* output,
__private const int dstChannelC4,
__private const int srcChannelC4,
__private const int srcChannel,
__private const int blockNum,
__private const int blockDim,
__private const float coef) {
const int ic = get_global_id(0);
const int oc = get_global_id(1); //oc/8
UNIFORM_BOUNDRY_CHECK_2(ic, oc);
const int oc8 = oc << 3;
const int loop = (blockDim + 2 - 1) / 2;
#if INPUT_CHANNEL_LEAVES_NUM != 0
const int loop_end = max(loop - 1, 0);
#else
const int loop_end = loop;
#endif
#ifndef USE_IMAGE
const int weight_offset = oc * srcChannelC4 * 32;
#endif
COMPUTE_FLOAT8 out0 = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8));
for (int i = 0; i < blockNum; i++){
COMPUTE_FLOAT16 scale, offset;
{
#ifdef ASYMMETRIC
COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace);
offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf);
#else
COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef);
scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset);
offset = 0;
#endif
}
for (int j = 0; j < loop_end; j++) {
int k = i * loop + j;
COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + (k << 1)));
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0);
}
{
out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0);
}
}
#if INPUT_CHANNEL_LEAVES_NUM != 0
{
int k = i * loop + loop_end;
#ifdef USE_IMAGE
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset;
#else
COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset;
#endif
{
out0 = mad((COMPUTE_FLOAT8)input[k << 1], wei.s01234567, out0);
}
}
#endif
}
#ifdef RELU
out0 = fmax(out0, (COMPUTE_FLOAT8)0);
@ -456,7 +467,6 @@ __kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2
#ifdef RELU6
out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);
#endif
#ifdef OUTPUT_CHANNEL_LEAVES
vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8);
if(oc8 + 4 < dstChannelC4 * 4)

View File

@ -7,19 +7,26 @@ const char* gemv_conv1x1_buf =
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define GLOBAL_SIZE_DIM_2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define GLOBAL_SIZE_DIM_3 "" __private int global_size_dim0,__private int global_size_dim1,__private int global_size_dim2,\n"
"#define UNIFORM_BOUNDRY_CHECK_2(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"#define UCHAR4_TO_CHAR8(b, scale, offset) "" wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); "" wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); "" wei.s2 = (COMPUTE_FLOAT)((b.s1 >> 4) - 8); "" wei.s3 = (COMPUTE_FLOAT)((b.s1 & 15) - 8); "" wei.s4 = (COMPUTE_FLOAT)((b.s2 >> 4) - 8); "" wei.s5 = (COMPUTE_FLOAT)((b.s2 & 15) - 8); "" wei.s6 = (COMPUTE_FLOAT)((b.s3 >> 4) - 8); "" wei.s7 = (COMPUTE_FLOAT)((b.s3 & 15) - 8); "" wei=wei*scale+offset;\n"
"#if WGS >= 8\n"
"__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2\n"
"__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
" #if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" #else\n"
" __global const uchar *weight,\n"
" #endif\n"
"#endif\n"
" __global const FLOAT *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelAlign,\n"
" __private const int srcChannelAlign,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
@ -29,32 +36,100 @@ const char* gemv_conv1x1_buf =
" const int lid=get_local_id(0);\n"
" const int oc=get_global_id(1); //oc/8\n"
" const int oc8=oc << 3;\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop=max((srcChannel+4-1)/4-1,0);\n"
"#else\n"
" const int loop=(srcChannel+4-1)/4;\n"
"#endif\n"
" __local COMPUTE_FLOAT8 sum[WGS];\n"
" COMPUTE_FLOAT8 out0=0;\n"
"#ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*16;\n"
"#endif\n"
" \n"
"#if QUANT_BIT == 8\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop=max((srcChannel+2-1)/2-1,0);\n"
" #else\n"
" const int loop=(srcChannel+2-1)/2;\n"
" #endif\n"
" #ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*32;\n"
" #endif\n"
"#else\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop=max((srcChannel+4-1)/4-1,0);\n"
" #else\n"
" const int loop=(srcChannel+4-1)/4;\n"
" #endif\n"
" #ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*16;\n"
" #endif\n"
"#endif\n"
" COMPUTE_FLOAT8 out0=0;\n"
" int input_offset=0,output_offset=oc8;\n"
" __local COMPUTE_FLOAT8 sum0[WGS];\n"
"#ifdef COMPUTE_BATCH\n"
" const int out_b_idx=get_global_id(2) << 2; //b/4\n"
" __local COMPUTE_FLOAT8 sum1[WGS];\n"
" __local COMPUTE_FLOAT8 sum2[WGS];\n"
" __local COMPUTE_FLOAT8 sum3[WGS];\n"
" COMPUTE_FLOAT8 out1=0,out2=0,out3=0;\n"
" input_offset=out_b_idx*srcChannelAlign;\n"
" output_offset=oc8+out_b_idx*dstChannelAlign;\n"
"#endif\n"
" for(int j=lid; j<loop; j+=WGS){\n"
" #if QUANT_BIT == 8\n"
" int k2=j << 1;\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k2/blockDim)*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k2/blockDim)*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" COMPUTE_FLOAT2 in=CONVERT_COMPUTE_FLOAT2(vload2(0,input+input_offset+k2));\n"
" #ifdef COMPUTE_BATCH\n"
" COMPUTE_FLOAT2 in1=CONVERT_COMPUTE_FLOAT2(vload2(0,input+input_offset+srcChannelAlign+k2));\n"
" COMPUTE_FLOAT2 in2=CONVERT_COMPUTE_FLOAT2(vload2(0,input+input_offset+srcChannelAlign*2+k2));\n"
" COMPUTE_FLOAT2 in3=CONVERT_COMPUTE_FLOAT2(vload2(0,input+input_offset+srcChannelAlign*3+k2));\n"
" #endif\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(j,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(j,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s0,wei.s01234567,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s0,wei.s01234567,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s0,wei.s01234567,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s0,wei.s01234567,out3);\n"
" #endif\n"
" }\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s1,wei.s89abcdef,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s1,wei.s89abcdef,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s1,wei.s89abcdef,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s1,wei.s89abcdef,out3);\n"
" #endif\n"
" }\n"
" #else\n"
" int k4=j << 2;\n"
"#ifdef ASYMMETRIC\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT8 scale,offset;\n"
" {\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k4/blockDim)*dstChannelC4*8))/coef);\n"
" scale=scaleOffset.s02468ace;\n"
" offset=scaleOffset.s13579bdf;\n"
" }\n"
"#else\n"
" #else\n"
" COMPUTE_FLOAT8 scale=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k4/blockDim)*dstChannelC4*4))/coef);\n"
" COMPUTE_FLOAT8 offset=0;\n"
"#endif\n"
" #endif\n"
" COMPUTE_FLOAT8 wei;\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k4));\n"
" COMPUTE_FLOAT4 in=CONVERT_COMPUTE_FLOAT4(vload4(0,input+k4+input_offset));\n"
" #ifdef COMPUTE_BATCH\n"
" COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+srcChannelAlign+k4));\n"
" COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+srcChannelAlign*2+k4));\n"
" COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4(vload4(0,input+input_offset+srcChannelAlign*3+k4));\n"
" #endif\n"
" #ifdef USE_IMAGE\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(j,oc)));\n"
" #else\n"
@ -63,39 +138,83 @@ const char* gemv_conv1x1_buf =
" {\n"
" UCHAR4_TO_CHAR8(charWeightsInt40.s0123,scale,offset);\n"
" out0=mad((COMPUTE_FLOAT8)in.s0,wei,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s0,wei,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s0,wei,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s0,wei,out3);\n"
" #endif\n"
" }\n"
" {\n"
" UCHAR4_TO_CHAR8(charWeightsInt40.s4567,scale,offset);\n"
" out0=mad((COMPUTE_FLOAT8)in.s1,wei,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s1,wei,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s1,wei,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s1,wei,out3);\n"
" #endif\n"
" }\n"
" {\n"
" UCHAR4_TO_CHAR8(charWeightsInt40.s89ab,scale,offset);\n"
" out0=mad((COMPUTE_FLOAT8)in.s2,wei,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s2,wei,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s2,wei,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s2,wei,out3);\n"
" #endif\n"
" }\n"
" {\n"
" UCHAR4_TO_CHAR8(charWeightsInt40.scdef,scale,offset);\n"
" out0=mad((COMPUTE_FLOAT8)in.s3,wei,out0);\n"
" #ifdef COMPUTE_BATCH\n"
" out1=mad((COMPUTE_FLOAT8)in1.s3,wei,out1);\n"
" out2=mad((COMPUTE_FLOAT8)in2.s3,wei,out2);\n"
" out3=mad((COMPUTE_FLOAT8)in3.s3,wei,out3);\n"
" #endif\n"
" }\n"
" #endif\n"
" }\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" {\n"
" #if QUANT_BIT == 8\n"
" int k2=loop << 1;\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k2/blockDim)*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k2/blockDim)*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(loop,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(loop,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)input[k2],wei.s01234567,out0);\n"
" }\n"
" #else\n"
" int k4=loop << 2;\n"
"#ifdef ASYMMETRIC\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT8 scale,offset;\n"
" {\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k4/blockDim)*dstChannelC4*8))/coef);\n"
" scale=scaleOffset.s02468ace;\n"
" offset=scaleOffset.s13579bdf;\n"
" }\n"
"#else\n"
" #else\n"
" COMPUTE_FLOAT8 scale=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k4/blockDim)*dstChannelC4*4))/coef);\n"
" COMPUTE_FLOAT8 offset=0;\n"
"#endif\n"
" #endif\n"
" COMPUTE_FLOAT8 wei;\n"
" #ifdef USE_IMAGE\n"
" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(loop,oc)));\n"
" #else\n"
" uchar16 charWeightsInt40=vload16(j,weight+weight_offset);\n"
" uchar16 charWeightsInt40=vload16(loop,weight+weight_offset);\n"
" #endif\n"
" {\n"
" UCHAR4_TO_CHAR8(charWeightsInt40.s0123,scale,offset);\n"
@ -113,17 +232,28 @@ const char* gemv_conv1x1_buf =
" out0=mad((COMPUTE_FLOAT8)input[k4+2],wei,out0);\n"
" }\n"
" #endif\n"
" #endif\n"
" }\n"
"#endif\n"
" sum[lid]=out0;\n"
" sum0[lid]=out0;\n"
" #ifdef COMPUTE_BATCH\n"
" sum1[lid]=out1; sum2[lid]=out2; sum3[lid]=out3;\n"
" #endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=WGS/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" if (lid<i){\n"
" sum0[lid]=sum0[lid]+sum0[lid+i];\n"
" #ifdef COMPUTE_BATCH\n"
" sum1[lid]=sum1[lid]+sum1[lid+i];\n"
" sum2[lid]=sum2[lid]+sum2[lid+i];\n"
" sum3[lid]=sum3[lid]+sum3[lid+i];\n"
" #endif\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" out0=sum[0]+CONVERT_COMPUTE_FLOAT8(vload8(0,bias+oc8));\n"
" COMPUTE_FLOAT8 vBias=CONVERT_COMPUTE_FLOAT8(vload8(0,bias+oc8));\n"
" out0=sum0[0]+vBias;\n"
" #ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT8)0);\n"
" #endif\n"
@ -131,130 +261,43 @@ const char* gemv_conv1x1_buf =
" out0=clamp(out0,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
" #endif\n"
" #ifdef OUTPUT_CHANNEL_LEAVES\n"
" vstore4(CONVERT_FLOAT4(out0.s0123),0,output+oc8);\n"
" vstore4(CONVERT_FLOAT4(out0.s0123),0,output+output_offset);\n"
" if(oc8+4<dstChannelC4*4)\n"
" vstore4(CONVERT_FLOAT4(out0.s4567),0,output+oc8+4);\n"
" vstore4(CONVERT_FLOAT4(out0.s4567),0,output+4+output_offset);\n"
" #else\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+oc8);\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+output_offset);\n"
" #endif\n"
" #ifdef COMPUTE_BATCH\n"
" out1=sum1[0]+vBias; out2=sum2[0]+vBias; out3=sum3[0]+vBias;\n"
" #ifdef RELU\n"
" out1=fmax(out1,(COMPUTE_FLOAT8)0);out2=fmax(out2,(COMPUTE_FLOAT8)0);out3=fmax(out3,(COMPUTE_FLOAT8)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" out1=clamp(out1,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);out2=clamp(out2,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);out3=clamp(out3,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
" #endif\n"
" vstore8(CONVERT_FLOAT8(out1),0,output+output_offset+dstChannelAlign);\n"
" vstore8(CONVERT_FLOAT8(out2),0,output+output_offset+dstChannelAlign+dstChannelAlign);\n"
" vstore8(CONVERT_FLOAT8(out3),0,output+output_offset+dstChannelAlign+dstChannelAlign+dstChannelAlign);\n"
" #endif\n"
" }\n"
"}\n"
"__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2\n"
"#else\n"
"__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
" #if QUANT_BIT == 8\n"
" __global const char *weight,\n"
"#endif\n"
" __global const FLOAT *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
" __private const int blockNum,\n"
" __private const int blockDim,\n"
" __private const float coef) {\n"
" const int lid=get_local_id(0);\n"
" const int oc=get_global_id(1); //oc/8\n"
" const int oc8=oc << 3;\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop=max((srcChannel+2-1)/2-1,0);\n"
"#else\n"
" const int loop=(srcChannel+2-1)/2;\n"
"#endif\n"
" __local COMPUTE_FLOAT8 sum[WGS];\n"
"#ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*32;\n"
"#endif\n"
" COMPUTE_FLOAT8 out0=0;\n"
" for(int j=lid; j<loop; j+=WGS){\n"
" int k2=j << 1;\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k2/blockDim)*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k2/blockDim)*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" COMPUTE_FLOAT2 in=CONVERT_COMPUTE_FLOAT2(vload2(0,input+k2));\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(j,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(j,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s0,wei.s01234567,out0);\n"
" }\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s1,wei.s89abcdef,out0);\n"
" }\n"
" }\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" {\n"
" int k2=loop << 1;\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+(k2/blockDim)*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+(k2/blockDim)*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(loop,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(j,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)input[k2],wei.s01234567,out0);\n"
" }\n"
" }\n"
"#endif\n"
" sum[lid]=out0;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" for(int i=WGS/2; i>0; i /= 2){\n"
" if (lid<i)\n"
" sum[lid]=sum[lid]+sum[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" if(lid == 0){\n"
" out0=sum[0]+CONVERT_COMPUTE_FLOAT8(vload8(0,bias+oc8));\n"
" #ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT8)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
" #endif\n"
" #ifdef OUTPUT_CHANNEL_LEAVES\n"
" vstore4(CONVERT_FLOAT4(out0.s0123),0,output+oc8);\n"
" if(oc8+4<dstChannelC4*4)\n"
" vstore4(CONVERT_FLOAT4(out0.s4567),0,output+oc8+4);\n"
" #else\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+oc8);\n"
" #endif\n"
" }\n"
"}\n"
"#else\n"
"__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
" __global const uchar *weight,\n"
" #endif\n"
"#endif\n"
" __global const FLOAT *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelAlign,\n"
" __private const int srcChannelAlign,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
@ -266,29 +309,82 @@ const char* gemv_conv1x1_buf =
" \n"
" UNIFORM_BOUNDRY_CHECK_2(ic,oc);\n"
" const int oc8=oc << 3;\n"
" \n"
" const int loop=(blockDim+4-1)/4;\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
"#if QUANT_BIT == 8\n"
" const int loop=(blockDim+2-1)/2;\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop_end=max(loop-1,0);\n"
"#else\n"
" #else\n"
" const int loop_end=loop;\n"
" #endif\n"
" #ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*32;\n"
" #endif\n"
"#else\n"
" const int loop=(blockDim+4-1)/4;\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop_end=max(loop-1,0);\n"
" #else\n"
" const int loop_end=loop;\n"
" #endif\n"
" #ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*16;\n"
" #endif\n"
"#endif\n"
" COMPUTE_FLOAT8 out0=CONVERT_COMPUTE_FLOAT8(vload8(0,bias+oc8));\n"
"#ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*16;\n"
"#endif\n"
" for (int i=0; i<blockNum; i++){\n"
"#ifdef ASYMMETRIC\n"
" #if QUANT_BIT == 8\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+i*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+i*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" for (int j=0; j<loop_end; j++) {\n"
" int k=i*loop+j;\n"
" COMPUTE_FLOAT2 in=CONVERT_COMPUTE_FLOAT2(vload2(0,input+(k << 1)));\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(k,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(k,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s0,wei.s01234567,out0);\n"
" }\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s1,wei.s89abcdef,out0);\n"
" }\n"
" }\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" {\n"
" int k=i*loop+loop_end;\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(k,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(k,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)input[k << 1],wei.s01234567,out0);\n"
" }\n"
" }\n"
" #endif\n"
" #else\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT8 scale,offset;\n"
" {\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+i*dstChannelC4*8))/coef);\n"
" scale=scaleOffset.s02468ace;\n"
" offset=scaleOffset.s13579bdf;\n"
" }\n"
"#else\n"
" #else\n"
" COMPUTE_FLOAT8 scale=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+i*dstChannelC4*4))/coef);\n"
" COMPUTE_FLOAT8 offset=0;\n"
"#endif\n"
" #endif\n"
" for (int j=0; j<loop_end; j++) {\n"
" int k=i*loop+j;\n"
" COMPUTE_FLOAT8 wei;\n"
@ -343,91 +439,6 @@ const char* gemv_conv1x1_buf =
" #endif\n"
" }\n"
" #endif\n"
"}\n"
"#ifdef RELU\n"
" out0=fmax(out0,(COMPUTE_FLOAT8)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(COMPUTE_FLOAT8)0,(COMPUTE_FLOAT8)6);\n"
"#endif\n"
" #ifdef OUTPUT_CHANNEL_LEAVES\n"
" vstore4(CONVERT_FLOAT4(out0.s0123),0,output+oc8);\n"
" if(oc8+4<dstChannelC4*4)\n"
" vstore4(CONVERT_FLOAT4(out0.s4567),0,output+oc8+4);\n"
" #else\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+oc8);\n"
" #endif\n"
"}\n"
"__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2\n"
" __global const FLOAT* input,\n"
"#ifdef USE_IMAGE\n"
" __read_only image2d_t weight,\n"
"#else\n"
" __global const char *weight,\n"
"#endif\n"
" __global const FLOAT *dequantScaleOffset,\n"
" __global const FLOAT *bias,\n"
" __global FLOAT* output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int srcChannel,\n"
" __private const int blockNum,\n"
" __private const int blockDim,\n"
" __private const float coef) {\n"
" const int ic=get_global_id(0);\n"
" const int oc=get_global_id(1); //oc/8\n"
" UNIFORM_BOUNDRY_CHECK_2(ic,oc);\n"
" const int oc8=oc << 3;\n"
" const int loop=(blockDim+2-1)/2;\n"
"#if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" const int loop_end=max(loop-1,0);\n"
"#else\n"
" const int loop_end=loop;\n"
"#endif\n"
"#ifndef USE_IMAGE\n"
" const int weight_offset=oc*srcChannelC4*32;\n"
"#endif\n"
" COMPUTE_FLOAT8 out0=CONVERT_COMPUTE_FLOAT8(vload8(0,bias+oc8));\n"
" for (int i=0; i<blockNum; i++){\n"
" COMPUTE_FLOAT16 scale,offset;\n"
" {\n"
" #ifdef ASYMMETRIC\n"
" COMPUTE_FLOAT16 scaleOffset=CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0,dequantScaleOffset+oc8*2+i*dstChannelC4*8))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset.s02468ace,scaleOffset.s02468ace);\n"
" offset=(COMPUTE_FLOAT16)(scaleOffset.s13579bdf,scaleOffset.s13579bdf);\n"
" #else\n"
" COMPUTE_FLOAT8 scaleOffset=CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0,dequantScaleOffset+oc8+i*dstChannelC4*4))/coef);\n"
" scale=(COMPUTE_FLOAT16)(scaleOffset,scaleOffset);\n"
" offset=0;\n"
" #endif\n"
" }\n"
" for (int j=0; j<loop_end; j++) {\n"
" int k=i*loop+j;\n"
" COMPUTE_FLOAT2 in=CONVERT_COMPUTE_FLOAT2(vload2(0,input+(k << 1)));\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(k,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(k,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s0,wei.s01234567,out0);\n"
" }\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)in.s1,wei.s89abcdef,out0);\n"
" }\n"
" }\n"
" #if INPUT_CHANNEL_LEAVES_NUM != 0\n"
" {\n"
" int k=i*loop+loop_end;\n"
" #ifdef USE_IMAGE\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(k,oc))))*scale+offset;\n"
" #else\n"
" COMPUTE_FLOAT16 wei=CONVERT_COMPUTE_FLOAT16(vload16(k,weight+weight_offset))*scale+offset;\n"
" #endif\n"
" {\n"
" out0=mad((COMPUTE_FLOAT8)input[k << 1],wei.s01234567,out0);\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
"#ifdef RELU\n"

View File

@ -139,6 +139,8 @@ __kernel void gl_to_cl(GLOBAL_SIZE_3_DIMS
vstore4(in2, 0, output_ptr + output_offset + 8);
if(w + 3 >= shape.w) return;
vstore4(in3, 0, output_ptr + output_offset + 12);
#else
//not support
#endif
#endif
}
@ -167,11 +169,12 @@ __kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS
int idx = c * shape.w + w; // c/4*w
int idy = nh; // n*h
INPUT_TYPE4 in0, in1, in2, in3;
#ifdef USE_IMAGE
INPUT_TYPE4 in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy));
INPUT_TYPE4 in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy));
INPUT_TYPE4 in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy));
INPUT_TYPE4 in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy));
in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy));
in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy));
in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy));
in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy));
#else
#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW
int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w;
@ -181,22 +184,24 @@ __kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS
tmp1 = vload4(0, input_ptr + input_offset + stride);
tmp2 = vload4(0, input_ptr + input_offset + stride + stride);
tmp3 = vload4(0, input_ptr + input_offset + stride + stride + stride);
INPUT_TYPE4 in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x);
INPUT_TYPE4 in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y);
INPUT_TYPE4 in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z);
INPUT_TYPE4 in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w);
in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x);
in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y);
in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z);
in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w);
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC
int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c;
INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset);
INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + shape.y);
INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y);
INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y);
in0 = vload4(0, input_ptr + input_offset);
in1 = vload4(0, input_ptr + input_offset + shape.y);
in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y);
in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y);
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
int input_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4;
INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset);
INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + 4);
INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + 8);
INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + 12);
in0 = vload4(0, input_ptr + input_offset);
in1 = vload4(0, input_ptr + input_offset + 4);
in2 = vload4(0, input_ptr + input_offset + 8);
in3 = vload4(0, input_ptr + input_offset + 12);
#else
// not support
#endif
#endif
const int offset = idy * shape.w * 4;

View File

@ -132,6 +132,8 @@ const char* glmem_convert =
" vstore4(in2,0,output_ptr+output_offset+8);\n"
" if(w+3 >= shape.w) return;\n"
" vstore4(in3,0,output_ptr+output_offset+12);\n"
" #else\n"
" //not support\n"
" #endif\n"
"#endif\n"
"}\n"
@ -157,11 +159,12 @@ const char* glmem_convert =
" \n"
" int idx=c*shape.w+w; // c/4*w\n"
" int idy=nh; // n*h\n"
" INPUT_TYPE4 in0,in1,in2,in3;\n"
"#ifdef USE_IMAGE\n"
" INPUT_TYPE4 in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n"
" INPUT_TYPE4 in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n"
" INPUT_TYPE4 in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n"
" INPUT_TYPE4 in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n"
" in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n"
" in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n"
" in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n"
" in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n"
"#else\n"
" #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n"
@ -171,22 +174,24 @@ const char* glmem_convert =
" tmp1=vload4(0,input_ptr+input_offset+stride);\n"
" tmp2=vload4(0,input_ptr+input_offset+stride+stride);\n"
" tmp3=vload4(0,input_ptr+input_offset+stride+stride+stride);\n"
" INPUT_TYPE4 in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" INPUT_TYPE4 in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" INPUT_TYPE4 in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" INPUT_TYPE4 in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n"
" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n"
" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+shape.y);\n"
" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n"
" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n"
" in0=vload4(0,input_ptr+input_offset);\n"
" in1=vload4(0,input_ptr+input_offset+shape.y);\n"
" in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n"
" in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n"
" #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int input_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n"
" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n"
" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+4);\n"
" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+8);\n"
" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+12);\n"
" in0=vload4(0,input_ptr+input_offset);\n"
" in1=vload4(0,input_ptr+input_offset+4);\n"
" in2=vload4(0,input_ptr+input_offset+8);\n"
" in3=vload4(0,input_ptr+input_offset+12);\n"
" #else\n"
" // not support\n"
" #endif\n"
"#endif\n"
" const int offset=idy*shape.w*4;\n"

View File

@ -2,6 +2,7 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
#if LOCAL_SIZE > 1
__kernel void groupnorm_plain_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
#ifdef DOUBLE_INPUTS
__global const FLOAT * input0,
@ -242,3 +243,4 @@ __kernel void groupnorm_plain_buf(__private int global_dim0, __private int globa
#endif
}
}
#endif

View File

@ -5,6 +5,7 @@ const char* groupnorm_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#if LOCAL_SIZE>1\n"
"__kernel void groupnorm_plain_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
"#ifdef DOUBLE_INPUTS\n"
" __global const FLOAT*input0,\n"
@ -226,6 +227,7 @@ const char* groupnorm_buf =
"#endif\n"
" }\n"
"}\n"
"#endif\n"
;
#endif
}

View File

@ -49,8 +49,9 @@ __kernel void conv_transe_c4_c1(
w * output_x_pitch;
FLOAT4 value = vload4(0, input + input_offset);
FLOAT *value_ptr = (FLOAT*)&value;
for(int i = 0; i < 4 && cout + i < input_channel; ++i){
output[output_offset + i * output_f_pitch] = value[i];
output[output_offset + i * output_f_pitch] = value_ptr[i];
}
}

View File

@ -48,8 +48,9 @@ const char* input_transe_buf =
" w*output_x_pitch;\n"
" \n"
" FLOAT4 value=vload4(0,input+input_offset);\n"
" FLOAT *value_ptr=(FLOAT*)&value;\n"
" for(int i=0; i<4 && cout+i<input_channel; ++i){\n"
" output[output_offset+i*output_f_pitch]=value[i];\n"
" output[output_offset+i*output_f_pitch]=value_ptr[i];\n"
" }\n"
"}\n"
"__attribute__((intel_reqd_sub_group_size(16)))\n"

View File

@ -3,7 +3,7 @@
#endif
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#ifdef LOCAL_SIZE
__kernel void layernorm_w(__private int global_dim0, __private int global_dim1, __private int global_dim2,
__read_only image2d_t input,
__write_only image2d_t output,
@ -249,3 +249,4 @@ __kernel void layernorm_chw(__private int global_dim0, __private int global_dim1
}
}
}
#endif

View File

@ -41,7 +41,7 @@ __kernel void layernorm_buf(__private int global_dim0, __private int global_dim1
#ifdef PACK_LEAVE
if(index == inside_v4 - 1) {
for(int i = 0; i < inside_remain; ++i)
for(int i = 0; i < inside_remain; ++i){
float in = input[offset + index * 4 + i];
sum_mean_mnn[lid] = sum_mean_mnn[lid] + in;
}
@ -67,9 +67,9 @@ __kernel void layernorm_buf(__private int global_dim0, __private int global_dim1
sum_mnn[lid] = in_sum.x + in_sum.y + in_sum.z + in_sum.w;
#ifdef PACK_LEAVE
if(index == inside_v4 - 1) {
for(int i = 0; i < inside_remain; ++i)
for(int i = 0; i < inside_remain; ++i) {
float in = input[offset + index * 4 + i];
in = (in - mean) * (in - mean);
in = (in - mean.x) * (in - mean.x);
sum_mnn[lid] = sum_mnn[lid] + in;
}
}

View File

@ -44,7 +44,7 @@ const char* layernorm_buf =
" \n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i)\n"
" for(int i=0; i<inside_remain; ++i){\n"
" float in=input[offset+index*4+i];\n"
" sum_mean_mnn[lid]=sum_mean_mnn[lid]+in;\n"
" }\n"
@ -69,9 +69,9 @@ const char* layernorm_buf =
" sum_mnn[lid]=in_sum.x+in_sum.y+in_sum.z+in_sum.w;\n"
" #ifdef PACK_LEAVE\n"
" if(index == inside_v4-1) {\n"
" for(int i=0; i<inside_remain; ++i)\n"
" for(int i=0; i<inside_remain; ++i) {\n"
" float in=input[offset+index*4+i];\n"
" in=(in-mean)*(in-mean);\n"
" in=(in-mean.x)*(in-mean.x);\n"
" sum_mnn[lid]=sum_mnn[lid]+in;\n"
" }\n"
" }\n"

View File

@ -5,6 +5,7 @@ const char* layernorm =
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#ifdef LOCAL_SIZE\n"
"__kernel void layernorm_w(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __read_only image2d_t input,\n"
" __write_only image2d_t output,\n"
@ -245,5 +246,6 @@ const char* layernorm =
" }\n"
" }\n"
"}\n"
"#endif\n"
;
}

View File

@ -40,13 +40,13 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1,
#endif
int4 offset = index * steps + offsets;
#if TRANSPOSE_A
#ifdef TRANSPOSE_A
__global FLOAT* A_ptr = input_A + offset.y + pos.y;
#else
__global FLOAT* A_ptr = input_A + offset.y + pos.y * l;
#endif
#if TRANSPOSE_B
#ifdef TRANSPOSE_B
__global FLOAT* B_ptr = input_B + offset.z + pos.x * l;
#else
__global FLOAT* B_ptr = input_B + offset.z + pos.x;
@ -68,7 +68,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1,
for(int i = 0; i < l_pack - 1; ++i){
int l_offset = i << 2;
FLOAT4 value_a0, value_a1, value_a2, value_a3, value_b0, value_b1, value_b2, value_b3;
#if TRANSPOSE_A
#ifdef TRANSPOSE_A
value_a0 = vload4(0, A_ptr + l_offset * e);
value_a1 = vload4(0, A_ptr + (l_offset + 1) * e);
value_a2 = vload4(0, A_ptr + (l_offset + 2) * e);
@ -80,7 +80,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1,
value_a3 = vload4(0, A_ptr + l_offset + 3 * l);
#endif
#if TRANSPOSE_B
#ifdef TRANSPOSE_B
FLOAT4 value_tmp0 = vload4(0, B_ptr + l_offset);
FLOAT4 value_tmp1 = vload4(0, B_ptr + l_offset + l);
FLOAT4 value_tmp2 = vload4(0, B_ptr + l_offset + 2 * l);
@ -140,7 +140,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1,
}
for(int i = ((l_pack - 1) << 2); i < l; ++i){
#if TRANSPOSE_A
#ifdef TRANSPOSE_A
FLOAT4 value_a = vload4(0, A_ptr + i * e);
#else
FLOAT4 value_a;
@ -150,7 +150,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1,
value_a.w = A_ptr[i + 3 * l];
#endif
#if TRANSPOSE_B
#ifdef TRANSPOSE_B
FLOAT4 value_b;
value_b.x = B_ptr[i];
value_b.y = B_ptr[i + l];
@ -323,7 +323,9 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1,
}
}
#ifdef LOOP_BINARY_OPERATOR
#ifndef OPERATOR
#define OPERATOR in0 + in1
#endif
__kernel void broadcast_binary(__private int global_dim0, __private int global_dim1, __private int global_dim2,
__write_only image2d_t output, __read_only image2d_t input0, __read_only image2d_t input1,
__private const int8 src0_size, //(batch, channel, height, width)
@ -422,13 +424,13 @@ __kernel void broadcast_binary(__private int global_dim0, __private int global_d
int4 out = in0 % in1;
out = ((out < (int4)0 && in1 > (int4)0) || (out > (int4)0 && in1 < (int4)0)) ? out + in1 : out;
#else
float4 out = LOOP_BINARY_OPERATOR;
float4 out = OPERATOR;
#endif
WI_DATA(output, (int2)(co * dst_width + wo, no * dst_height + ho), CONVERT_OUTPUT_I4(out));
}
}
#ifdef COMPUTE_CUMSUM
__kernel void loop_cumsum(__private int global_dim0, __private int global_dim1, __private int global_dim2,
__global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1,
__private const int input0Stride0,
@ -455,21 +457,20 @@ __kernel void loop_cumsum(__private int global_dim0, __private int global_dim1,
int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2;
int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2;
float in0 = 0;
float4 in0 = 0;
if(offsets.z != offsets.y){
in0 = (float)input0[inputIndex0];
in0.x = (float)input0[inputIndex0];
}
for(int i = 0; i < loopNumber; ++i){
int4 offset = (int4)i * steps + offsets;
float in1 = (float)input1[inputIndex1 + offset.z];
float out = LOOP_BINARY_OPERATOR;
float4 in1;
in1.x = (float)input1[inputIndex1 + offset.z];
float4 out = OPERATOR;
output[outputIndex + offset.x] = (OUTPUT_TYPE)out;
in0 = out;
output[outputIndex + offset.x] = (OUTPUT_TYPE)out.x;
in0.x = out.x;
}
}
}
#endif
#endif

View File

@ -387,7 +387,9 @@ __kernel void pack_buf(__private int global_dim0, __private int global_dim1, __p
}
}
#ifdef LOOP_BINARY_OPERATOR
#ifndef OPERATOR
#define OPERATOR in0 + in1
#endif
__kernel void loop_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
__global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1,
__private const int input0Stride0,
@ -414,11 +416,11 @@ __kernel void loop_binary_buf(__private int global_dim0, __private int global_di
int in0 = (int)input0[inputIndex0];
int in1 = (int)input1[inputIndex1];
int out = in0 % in1;
out = ((out < (int4)0 && in1 > (int4)0) || (out > (int4)0 && in1 < (int4)0)) ? out + in1 : out;
out = ((out < 0 && in1 > 0) || (out > 0 && in1 < 0)) ? out + in1 : out;
#else
float in0 = (float)input0[inputIndex0];
float in1 = (float)input1[inputIndex1];
float out = LOOP_BINARY_OPERATOR;
float out = OPERATOR;
#endif
output[outputIndex] = (OUTPUT_TYPE)out;
}
@ -458,11 +460,10 @@ __kernel void loop_cumsum_buf(__private int global_dim0, __private int global_di
for(int i = 0; i < loopNumber; ++i){
int4 offset = (int4)i * steps + offsets;
float in1 = (float)input1[inputIndex1 + offset.z];
float out = LOOP_BINARY_OPERATOR;
float out = OPERATOR;
output[outputIndex + offset.x] = (OUTPUT_TYPE)out;
in0 = out;
}
}
}
#endif

View File

@ -379,7 +379,9 @@ const char* loop_buf =
" vstore4(value,0,output+b*b_dst_pitch+c_4*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch);\n"
" }\n"
"}\n"
"#ifdef LOOP_BINARY_OPERATOR\n"
"#ifndef OPERATOR\n"
" #define OPERATOR in0+in1\n"
"#endif\n"
"__kernel void loop_binary_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n"
" __private const int input0Stride0,\n"
@ -406,11 +408,11 @@ const char* loop_buf =
" int in0=(int)input0[inputIndex0];\n"
" int in1=(int)input1[inputIndex1];\n"
" int out=in0 % in1;\n"
" out=((out<(int4)0 && in1>(int4)0) || (out>(int4)0 && in1<(int4)0)) ? out+in1 : out;\n"
" out=((out<0 && in1>0) || (out>0 && in1<0)) ? out+in1 : out;\n"
" #else\n"
" float in0=(float)input0[inputIndex0];\n"
" float in1=(float)input1[inputIndex1];\n"
" float out=LOOP_BINARY_OPERATOR;\n"
" float out=OPERATOR;\n"
" #endif\n"
" output[outputIndex]=(OUTPUT_TYPE)out;\n"
" }\n"
@ -449,14 +451,13 @@ const char* loop_buf =
" for(int i=0; i<loopNumber; ++i){\n"
" int4 offset=(int4)i*steps+offsets;\n"
" float in1=(float)input1[inputIndex1+offset.z];\n"
" float out=LOOP_BINARY_OPERATOR;\n"
" float out=OPERATOR;\n"
" \n"
" output[outputIndex+offset.x]=(OUTPUT_TYPE)out;\n"
" in0=out;\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
;
#endif
}

View File

@ -42,12 +42,12 @@ const char* loop =
"#endif\n"
" int4 offset=index*steps+offsets;\n"
" \n"
"#if TRANSPOSE_A\n"
"#ifdef TRANSPOSE_A\n"
" __global FLOAT* A_ptr=input_A+offset.y+pos.y;\n"
"#else\n"
" __global FLOAT* A_ptr=input_A+offset.y+pos.y*l;\n"
"#endif\n"
"#if TRANSPOSE_B\n"
"#ifdef TRANSPOSE_B\n"
" __global FLOAT* B_ptr=input_B+offset.z+pos.x*l;\n"
"#else\n"
" __global FLOAT* B_ptr=input_B+offset.z+pos.x;\n"
@ -67,7 +67,7 @@ const char* loop =
" for(int i=0; i<l_pack-1; ++i){\n"
" int l_offset=i << 2;\n"
" FLOAT4 value_a0,value_a1,value_a2,value_a3,value_b0,value_b1,value_b2,value_b3;\n"
"#if TRANSPOSE_A\n"
"#ifdef TRANSPOSE_A\n"
" value_a0=vload4(0,A_ptr+l_offset*e);\n"
" value_a1=vload4(0,A_ptr+(l_offset+1)*e);\n"
" value_a2=vload4(0,A_ptr+(l_offset+2)*e);\n"
@ -78,7 +78,7 @@ const char* loop =
" value_a2=vload4(0,A_ptr+l_offset+2*l);\n"
" value_a3=vload4(0,A_ptr+l_offset+3*l);\n"
"#endif\n"
"#if TRANSPOSE_B\n"
"#ifdef TRANSPOSE_B\n"
" FLOAT4 value_tmp0=vload4(0,B_ptr+l_offset);\n"
" FLOAT4 value_tmp1=vload4(0,B_ptr+l_offset+l);\n"
" FLOAT4 value_tmp2=vload4(0,B_ptr+l_offset+2*l);\n"
@ -136,7 +136,7 @@ const char* loop =
"#endif\n"
" }\n"
" for(int i=((l_pack-1) << 2); i<l; ++i){\n"
"#if TRANSPOSE_A\n"
"#ifdef TRANSPOSE_A\n"
" FLOAT4 value_a=vload4(0,A_ptr+i*e);\n"
"#else\n"
" FLOAT4 value_a;\n"
@ -145,7 +145,7 @@ const char* loop =
" value_a.z=A_ptr[i+2*l];\n"
" value_a.w=A_ptr[i+3*l];\n"
"#endif\n"
"#if TRANSPOSE_B\n"
"#ifdef TRANSPOSE_B\n"
" FLOAT4 value_b;\n"
" value_b.x=B_ptr[i];\n"
" value_b.y=B_ptr[i+l];\n"
@ -310,7 +310,9 @@ const char* loop =
" }\n"
" }\n"
"}\n"
"#ifdef LOOP_BINARY_OPERATOR\n"
"#ifndef OPERATOR\n"
" #define OPERATOR in0+in1\n"
"#endif\n"
"__kernel void broadcast_binary(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __write_only image2d_t output,__read_only image2d_t input0,__read_only image2d_t input1,\n"
" __private const int8 src0_size,//(batch,channel,height,width)\n"
@ -409,13 +411,12 @@ const char* loop =
" int4 out=in0 % in1;\n"
" out=((out<(int4)0 && in1>(int4)0) || (out>(int4)0 && in1<(int4)0)) ? out+in1 : out;\n"
" #else\n"
" float4 out=LOOP_BINARY_OPERATOR;\n"
" float4 out=OPERATOR;\n"
" #endif\n"
" \n"
" WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n"
" }\n"
"}\n"
"#ifdef COMPUTE_CUMSUM\n"
"__kernel void loop_cumsum(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n"
" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n"
" __private const int input0Stride0,\n"
@ -442,22 +443,21 @@ const char* loop =
" int inputIndex1=z*input1Stride0+y*input1Stride1+x*input1Stride2;\n"
" int outputIndex=z*outputStride0+y*outputStride1+x*outputStride2;\n"
" \n"
" float in0=0;\n"
" float4 in0=0;\n"
" if(offsets.z != offsets.y){\n"
" in0=(float)input0[inputIndex0];\n"
" in0.x=(float)input0[inputIndex0];\n"
" }\n"
" \n"
" for(int i=0; i<loopNumber; ++i){\n"
" int4 offset=(int4)i*steps+offsets;\n"
" float in1=(float)input1[inputIndex1+offset.z];\n"
" float out=LOOP_BINARY_OPERATOR;\n"
" float4 in1;\n"
" in1.x=(float)input1[inputIndex1+offset.z];\n"
" float4 out=OPERATOR;\n"
" \n"
" output[outputIndex+offset.x]=(OUTPUT_TYPE)out;\n"
" in0=out;\n"
" output[outputIndex+offset.x]=(OUTPUT_TYPE)out.x;\n"
" in0.x=out.x;\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
"#endif\n"
;
}

View File

@ -687,10 +687,9 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, COMPUTE_FLOATM c_value, cons
Multiply(result.sE, alpha, xval.sE);
Multiply(result.sF, alpha, xval.sF);
#endif
#endif
// The final multiplication with alpha and the addition with beta*C
#ifdef HAVE_ALPHA_BETA
#elif HAVE_ALPHA_BETA
COMPUTE_FLOATM xval = c_value;
COMPUTE_FLOATM yval = CONVERT_COMPUTE_FLOATM(cgm[index]);
#if VWM == 1
@ -821,10 +820,9 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, COMPUTE_FLOATN c_value,
Multiply(result.sE, alpha, xval.sE);
Multiply(result.sF, alpha, xval.sF);
#endif
#endif
// The final multiplication with alpha and the addition with beta*C
#ifdef HAVE_ALPHA_BETA
#elif HAVE_ALPHA_BETA
COMPUTE_FLOATN xval = c_value;
COMPUTE_FLOATN yval = CONVERT_COMPUTE_FLOATN(cgn[index]);
#if VWN == 1

View File

@ -631,9 +631,8 @@ const char* matmul_params_buf =
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" #elif HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATM xval=c_value;\n"
" COMPUTE_FLOATM yval=CONVERT_COMPUTE_FLOATM(cgm[index]);\n"
" #if VWM == 1\n"
@ -761,9 +760,8 @@ const char* matmul_params_buf =
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" #elif HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATN xval=c_value;\n"
" COMPUTE_FLOATN yval=CONVERT_COMPUTE_FLOATN(cgn[index]);\n"
" #if VWN == 1\n"

View File

@ -16,7 +16,6 @@ extern const char* range_buf;
#ifndef MNN_OPENCL_BUFFER_CLOSED
extern const char* self_attention_buf;
#endif
extern const char* performance;
extern const char* winogradTransformSource2_3_1;
#ifndef MNN_OPENCL_BUFFER_CLOSED
extern const char* gemv_conv1x1_buf;
@ -90,6 +89,7 @@ extern const char* buffer_convert_quant;
#ifndef MNN_OPENCL_BUFFER_CLOSED
extern const char* gemm_buf;
#endif
extern const char* conv_2d_int;
extern const char* copy_buffer_to_image2d;
extern const char* loop;
#ifndef MNN_OPENCL_BUFFER_CLOSED
@ -124,6 +124,7 @@ extern const char* pooling;
#ifndef MNN_OPENCL_BUFFER_CLOSED
extern const char* conv_2d_buf;
#endif
extern const char* gemm_int;
extern const char* buffer_to_image;
extern const char* winogradTransformDest2_3_1;
#ifndef MNN_OPENCL_BUFFER_CLOSED
@ -188,7 +189,6 @@ const std::map<std::string, const char*> OpenCLProgramMap =
#ifndef MNN_OPENCL_BUFFER_CLOSED
{ "self_attention_buf", self_attention_buf },
#endif
{ "performance", performance },
{ "winogradTransformSource2_3_1", winogradTransformSource2_3_1 },
#ifndef MNN_OPENCL_BUFFER_CLOSED
{ "gemv_conv1x1_buf", gemv_conv1x1_buf },
@ -262,6 +262,7 @@ const std::map<std::string, const char*> OpenCLProgramMap =
#ifndef MNN_OPENCL_BUFFER_CLOSED
{ "gemm_buf", gemm_buf },
#endif
{ "conv_2d_int", conv_2d_int },
{ "copy_buffer_to_image2d", copy_buffer_to_image2d },
{ "loop", loop },
#ifndef MNN_OPENCL_BUFFER_CLOSED
@ -296,6 +297,7 @@ const std::map<std::string, const char*> OpenCLProgramMap =
#ifndef MNN_OPENCL_BUFFER_CLOSED
{ "conv_2d_buf", conv_2d_buf },
#endif
{ "gemm_int", gemm_int },
{ "buffer_to_image", buffer_to_image },
{ "winogradTransformDest2_3_1", winogradTransformDest2_3_1 },
#ifndef MNN_OPENCL_BUFFER_CLOSED

View File

@ -1,95 +0,0 @@
#define MAD_V4(x, y) \
x = mad(y, x, y); \
y = mad(x, y, x); \
x = mad(y, x, y); \
y = mad(x, y, x);
#define MAD_V16(x, y) \
MAD_V4(x, y); \
MAD_V4(x, y); \
MAD_V4(x, y); \
MAD_V4(x, y);
#define MAD_V64(x, y) \
MAD_V16(x, y); \
MAD_V16(x, y); \
MAD_V16(x, y); \
MAD_V16(x, y);
#define MAD_V128(x, y) \
MAD_V64(x, y); \
MAD_V64(x, y); \
MAD_V64(x, y); \
MAD_V64(x, y);
#define MAD_V256(x, y) \
MAD_V128(x, y); \
MAD_V128(x, y); \
MAD_V128(x, y); \
MAD_V128(x, y);
#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
__kernel void float_precision(__global float* output_ptr, float mul_value) {
float mul_x = mul_value;
float mul_y = (float)get_local_id(0);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
output_ptr[get_global_id(0)] = mul_y;
}
__kernel void half4_precision(__global half* output_ptr, float mul_value) {
half mul = (half)mul_value;
half4 mul_x = (half4)(mul);
half4 mul_y = (half4)get_local_id(0);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
MAD_V256(mul_x, mul_y);
output_ptr[get_global_id(0)] = (mul_y.S0) + (mul_y.S1) + (mul_y.S2) + (mul_y.S3);
}

View File

@ -1,72 +0,0 @@
#include "opencl_source_map.hpp"
namespace MNN {
const char* performance =
"#define MAD_V4(x, y) "" x = mad(y, x, y); "" y = mad(x, y, x); "" x = mad(y, x, y); "" y=mad(x,y,x);\n"
"#define MAD_V16(x, y) "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x,y);\n"
"#define MAD_V64(x, y) "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x,y);\n"
"#define MAD_V128(x, y) "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x,y);\n"
"#define MAD_V256(x, y) "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x,y);\n"
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void float_precision(__global float* output_ptr,float mul_value) {\n"
" float mul_x=mul_value;\n"
" float mul_y=(float)get_local_id(0);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" output_ptr[get_global_id(0)]=mul_y;\n"
"}\n"
"__kernel void half4_precision(__global half* output_ptr,float mul_value) {\n"
" half mul=(half)mul_value;\n"
" half4 mul_x=(half4)(mul);\n"
" half4 mul_y=(half4)get_local_id(0);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" MAD_V256(mul_x,mul_y);\n"
" output_ptr[get_global_id(0)]=(mul_y.S0)+(mul_y.S1)+(mul_y.S2)+(mul_y.S3);\n"
"}\n"
;
}

View File

@ -29,6 +29,9 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
const int input_width_start = mad24(output_width_idx, stride_shape.y, -pad_shape.y);
const int input_channel_start = mul24(output_channel_idx, input_shape.y);
#ifdef RETURN_REDICE
int4 redice = (int4)0;
#endif
#ifdef POOL_AVG
FLOAT4 output_result = 0;
for (int height = 0; height < kernel_shape.x; height++) {
@ -58,9 +61,6 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
output_result = output_result * block_float_req;
#else
FLOAT4 output_result = (FLOAT4)(-FLT_MAX);
#if RETURN_REDICE
int4 redice = (int4)0;
#endif
for (int height = 0; height < kernel_shape.x; height++) {
int input_height_idx = input_height_start + height;
input_height_idx =
@ -73,7 +73,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
if (input_width_idx != -1) {
FLOAT4 input_data = RI_F(input, SAMPLER, (int2)(input_width_idx, input_height_idx));
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = input_data > output_result ? (int4)((input_height_start + height) * input_shape.y + input_width_start + width) : redice;
#endif
output_result = fmax(output_result, input_data);
@ -85,12 +85,12 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
const int output_channel_width_idx = mad24(output_channel_idx, output_width, output_width_idx);
WI_F(output, (int2)(output_channel_width_idx, output_batch_height_idx), output_result);
#if RETURN_REDICE
#ifdef RETURN_REDICE
WI_F(rediceOutput, (int2)(output_channel_width_idx, output_batch_height_idx), CONVERT_FLOAT4(redice));
#endif
}
#ifdef LOCAL_SIZE
#if LOCAL_SIZE > 1
__kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
__private const int2 input_shape, __private const int output_height, __private const int2 pad_shape,
__private const int2 stride_shape,
@ -105,10 +105,10 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
FLOAT4 output_result = 0;
#else
FLOAT4 output_result = (FLOAT4)(-FLT_MAX);
#if RETURN_REDICE
#endif
#ifdef RETURN_REDICE
int4 redice = (int4)0;
int4 local rediceId[LOCAL_SIZE];
#endif
#endif
FLOAT4 local sum_mnn[LOCAL_SIZE];
@ -122,14 +122,14 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
output_result += in;
#else
output_result = fmax(output_result, in);
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = in > output_result ? (int4)(i) : redice;
#endif
#endif
}
sum_mnn[local_id] = output_result;
#if RETURN_REDICE
#ifdef RETURN_REDICE
rediceId[local_id] = redice;
#endif
barrier(CLK_LOCAL_MEM_FENCE);
@ -139,7 +139,7 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
sum_mnn[local_id] = sum_mnn[local_id] + sum_mnn[local_id + i];
#else
{
#if RETURN_REDICE
#ifdef RETURN_REDICE
rediceId[local_id] = sum_mnn[local_id] > sum_mnn[local_id + i] ? rediceId[local_id] : rediceId[local_id + i];
#endif
sum_mnn[local_id] = fmax(sum_mnn[local_id], sum_mnn[local_id + i]);
@ -153,7 +153,7 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
#endif
WI_F(output, (int2)(output_channel_idx, output_batch_idx), output_result);
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = rediceId[0];
WI_F(rediceOutput, (int2)(output_channel_idx, output_batch_idx), CONVERT_FLOAT4(redice));
#endif

View File

@ -29,6 +29,9 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
const int iw_start = mad24(ow_idx, stride_shape.y, -pad_shape.y);
const int ih_start = mad24(oh_idx, stride_shape.x, -pad_shape.x);
#ifdef RETURN_REDICE
int4 redice = (int4)0;
#endif
#ifdef POOL_AVG
COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(0);
const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;
@ -57,9 +60,6 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
result = result / (COMPUTE_FLOAT4)(1.0*total_count);
#else
COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(-FLT_MAX);
#if RETURN_REDICE
int4 redice = (int4)0;
#endif
const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;
for(int kh=0; kh<kernel_shape.x; kh++) {
int ih_cur = ih_start + kh;
@ -72,7 +72,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
continue;
}
COMPUTE_FLOAT4 inp_data = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset+(kh*input_shape.y+kw)*4));
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = inp_data > result ? (int4)((ih_start + kh) * input_shape.y + iw_start + kw) : redice;
#endif
result = fmax(result, inp_data);
@ -82,7 +82,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
const int out_offset = (((b_idx + c_idx*batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4;
vstore4(CONVERT_FLOAT4(result), 0, output+out_offset);
#if RETURN_REDICE
#ifdef RETURN_REDICE
vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+out_offset);
#endif
}
@ -105,10 +105,10 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
COMPUTE_FLOAT4 output_result = 0;
#else
COMPUTE_FLOAT4 output_result = (COMPUTE_FLOAT4)(-FLT_MAX);
#if RETURN_REDICE
#endif
#ifdef RETURN_REDICE
int4 redice = (int4)0;
int4 local rediceId[LOCAL_SIZE];
#endif
#endif
COMPUTE_FLOAT4 local sum_mnn[LOCAL_SIZE];
@ -122,14 +122,14 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
output_result += in;
#else
output_result = fmax(output_result, in);
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = in > output_result ? (int4)(i) : redice;
#endif
#endif
}
sum_mnn[local_id] = output_result;
#if RETURN_REDICE
#ifdef RETURN_REDICE
rediceId[local_id] = redice;
#endif
barrier(CLK_LOCAL_MEM_FENCE);
@ -139,7 +139,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
sum_mnn[local_id] = sum_mnn[local_id] + sum_mnn[local_id + i];
#else
{
#if RETURN_REDICE
#ifdef RETURN_REDICE
rediceId[local_id] = sum_mnn[local_id] > sum_mnn[local_id + i] ? rediceId[local_id] : rediceId[local_id + i];
#endif
sum_mnn[local_id] = fmax(sum_mnn[local_id], sum_mnn[local_id + i]);
@ -154,7 +154,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,
const int out_offset = (output_batch_idx + output_channel_idx*batch)*4;
vstore4(CONVERT_FLOAT4(output_result), 0, output+out_offset);
#if RETURN_REDICE
#ifdef RETURN_REDICE
redice = rediceId[0];
vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+out_offset);
#endif

View File

@ -27,6 +27,9 @@ const char* pooling_buf =
" const int iw_start=mad24(ow_idx,stride_shape.y,-pad_shape.y);\n"
" const int ih_start=mad24(oh_idx,stride_shape.x,-pad_shape.x);\n"
" \n"
" #ifdef RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" #ifdef POOL_AVG\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n"
@ -55,9 +58,6 @@ const char* pooling_buf =
" result=result/(COMPUTE_FLOAT4)(1.0*total_count);\n"
" #else\n"
" COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n"
" for(int kh=0; kh<kernel_shape.x; kh++) {\n"
" int ih_cur=ih_start+kh;\n"
@ -70,7 +70,7 @@ const char* pooling_buf =
" continue;\n"
" }\n"
" COMPUTE_FLOAT4 inp_data=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset+(kh*input_shape.y+kw)*4));\n"
" #if RETURN_REDICE\n"
" #ifdef RETURN_REDICE\n"
" redice=inp_data>result ? (int4)((ih_start+kh)*input_shape.y+iw_start+kw) : redice;\n"
" #endif\n"
" result=fmax(result,inp_data);\n"
@ -80,7 +80,7 @@ const char* pooling_buf =
" \n"
" const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n"
" vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n"
" #if RETURN_REDICE\n"
" #ifdef RETURN_REDICE\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n"
" #endif\n"
"}\n"
@ -101,11 +101,11 @@ const char* pooling_buf =
" COMPUTE_FLOAT4 output_result=0;\n"
"#else\n"
" COMPUTE_FLOAT4 output_result=(COMPUTE_FLOAT4)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
"#endif\n"
"#ifdef RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" int4 local rediceId[LOCAL_SIZE];\n"
"#endif\n"
"#endif\n"
" COMPUTE_FLOAT4 local sum_mnn[LOCAL_SIZE];\n"
" const int inp_offset=((output_batch_idx+output_channel_idx*batch)*input_shape.x)*input_shape.y*4;\n"
" const int size=input_shape.x*input_shape.y;\n"
@ -117,14 +117,14 @@ const char* pooling_buf =
" output_result += in;\n"
"#else\n"
" output_result=fmax(output_result,in);\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" redice=in>output_result ? (int4)(i) : redice;\n"
"#endif\n"
"#endif\n"
" }\n"
" \n"
" sum_mnn[local_id]=output_result;\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" rediceId[local_id]=redice;\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
@ -134,7 +134,7 @@ const char* pooling_buf =
" sum_mnn[local_id]=sum_mnn[local_id]+sum_mnn[local_id+i];\n"
"#else\n"
" {\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" rediceId[local_id]=sum_mnn[local_id]>sum_mnn[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n"
"#endif\n"
" sum_mnn[local_id]=fmax(sum_mnn[local_id],sum_mnn[local_id+i]);\n"
@ -148,7 +148,7 @@ const char* pooling_buf =
"#endif\n"
" const int out_offset=(output_batch_idx+output_channel_idx*batch)*4;\n"
" vstore4(CONVERT_FLOAT4(output_result),0,output+out_offset);\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" redice=rediceId[0];\n"
" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n"
"#endif\n"

View File

@ -24,6 +24,9 @@ const char* pooling =
" const int input_height_start=mad24(output_height_idx,stride_shape.x,-pad_shape.x);\n"
" const int input_width_start=mad24(output_width_idx,stride_shape.y,-pad_shape.y);\n"
" const int input_channel_start=mul24(output_channel_idx,input_shape.y);\n"
" #ifdef RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
"#ifdef POOL_AVG\n"
" FLOAT4 output_result=0;\n"
" for (int height=0; height<kernel_shape.x; height++) {\n"
@ -51,9 +54,6 @@ const char* pooling =
" output_result=output_result*block_float_req;\n"
"#else\n"
" FLOAT4 output_result=(FLOAT4)(-FLT_MAX);\n"
" #if RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" #endif\n"
" for (int height=0; height<kernel_shape.x; height++) {\n"
" int input_height_idx=input_height_start+height;\n"
" input_height_idx =\n"
@ -65,7 +65,7 @@ const char* pooling =
" (input_width_idx<0 || input_width_idx >= input_shape.y));\n"
" if (input_width_idx != -1) {\n"
" FLOAT4 input_data=RI_F(input,SAMPLER,(int2)(input_width_idx,input_height_idx));\n"
" #if RETURN_REDICE\n"
" #ifdef RETURN_REDICE\n"
" redice=input_data>output_result ? (int4)((input_height_start+height)*input_shape.y+input_width_start+width) : redice;\n"
" #endif\n"
" output_result=fmax(output_result,input_data);\n"
@ -76,11 +76,11 @@ const char* pooling =
"#endif\n"
" const int output_channel_width_idx=mad24(output_channel_idx,output_width,output_width_idx);\n"
" WI_F(output,(int2)(output_channel_width_idx,output_batch_height_idx),output_result);\n"
" #if RETURN_REDICE\n"
" #ifdef RETURN_REDICE\n"
" WI_F(rediceOutput,(int2)(output_channel_width_idx,output_batch_height_idx),CONVERT_FLOAT4(redice));\n"
" #endif\n"
"}\n"
"#ifdef LOCAL_SIZE\n"
"#if LOCAL_SIZE>1\n"
"__kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n"
" __private const int2 input_shape,__private const int output_height,__private const int2 pad_shape,\n"
" __private const int2 stride_shape,\n"
@ -94,11 +94,11 @@ const char* pooling =
" FLOAT4 output_result=0;\n"
"#else\n"
" FLOAT4 output_result=(FLOAT4)(-FLT_MAX);\n"
"#if RETURN_REDICE\n"
"#endif\n"
"#ifdef RETURN_REDICE\n"
" int4 redice=(int4)0;\n"
" int4 local rediceId[LOCAL_SIZE];\n"
"#endif\n"
"#endif\n"
" FLOAT4 local sum_mnn[LOCAL_SIZE];\n"
" int wc=output_channel_idx*input_shape.y;\n"
" int bh=output_batch_idx*input_shape.x;\n"
@ -110,14 +110,14 @@ const char* pooling =
" output_result += in;\n"
"#else\n"
" output_result=fmax(output_result,in);\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" redice=in>output_result ? (int4)(i) : redice;\n"
"#endif\n"
"#endif\n"
" }\n"
" \n"
" sum_mnn[local_id]=output_result;\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" rediceId[local_id]=redice;\n"
"#endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
@ -127,7 +127,7 @@ const char* pooling =
" sum_mnn[local_id]=sum_mnn[local_id]+sum_mnn[local_id+i];\n"
"#else\n"
" {\n"
"#if RETURN_REDICE\n"
"#ifdef RETURN_REDICE\n"
" rediceId[local_id]=sum_mnn[local_id]>sum_mnn[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n"
"#endif\n"
" sum_mnn[local_id]=fmax(sum_mnn[local_id],sum_mnn[local_id+i]);\n"
@ -140,7 +140,7 @@ const char* pooling =
" output_result /= (input_shape.x*input_shape.y);\n"
"#endif\n"
" WI_F(output,(int2)(output_channel_idx,output_batch_idx),output_result);\n"
" #if RETURN_REDICE\n"
" #ifdef RETURN_REDICE\n"
" redice=rediceId[0];\n"
" WI_F(rediceOutput,(int2)(output_channel_idx,output_batch_idx),CONVERT_FLOAT4(redice));\n"
" #endif\n"

View File

@ -69,28 +69,30 @@ __kernel void raster_direct_buffer(
int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2;
int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2;
int inputIndexReal = 0;
int outputIndexReal = 0;
#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW
int inputIndexReal = inputIndex;
inputIndexReal = inputIndex;
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC
int inputIndexReal = inputIndex;
inputIndexReal = inputIndex;
#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
int in_w = inputIndex % src_width; inputIndex /= src_width;
int in_h = inputIndex % src_height; inputIndex /= src_height;
int in_c = inputIndex % src_channel;
int in_b = inputIndex / src_channel;
int inputIndexReal = (((in_b + (in_c / 4) * src_batch) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4);
inputIndexReal = (((in_b + (in_c / 4) * src_batch) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4);
#endif
#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW
int outputIndexReal = outputIndex;
outputIndexReal = outputIndex;
#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC
int outputIndexReal = outputIndex;
outputIndexReal = outputIndex;
#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4
int out_w = outputIndex % dst_width; outputIndex /= dst_width;
int out_h = outputIndex % dst_height; outputIndex /= dst_height;
int out_c = outputIndex % dst_channel;
int out_b = outputIndex / dst_channel;
int outputIndexReal = (((out_b + (out_c / 4) * dst_batch) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4);
outputIndexReal = (((out_b + (out_c / 4) * dst_batch) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4);
#endif
output[outputIndexReal] = (OUTPUT_TYPE)input[inputIndexReal];
}

View File

@ -58,28 +58,30 @@ const char* raster_buf =
" \n"
" int inputIndex=inputOffset+id*combineSrcOffset+z*inputStride0+y*inputStride1+x*inputStride2;\n"
" int outputIndex=outputOffset+id*combineDstOffset+z*outputStride0+y*outputStride1+x*outputStride2;\n"
" int inputIndexReal=0;\n"
" int outputIndexReal=0;\n"
"#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int inputIndexReal=inputIndex;\n"
" inputIndexReal=inputIndex;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int inputIndexReal=inputIndex;\n"
" inputIndexReal=inputIndex;\n"
"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int in_w=inputIndex % src_width; inputIndex /= src_width;\n"
" int in_h=inputIndex % src_height; inputIndex /= src_height;\n"
" int in_c=inputIndex % src_channel;\n"
" int in_b=inputIndex/src_channel;\n"
" int inputIndexReal=(((in_b+(in_c/4)*src_batch)*src_height+in_h)*src_width+in_w)*4+(in_c % 4);\n"
" inputIndexReal=(((in_b+(in_c/4)*src_batch)*src_height+in_h)*src_width+in_w)*4+(in_c % 4);\n"
"#endif\n"
" \n"
"#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n"
" int outputIndexReal=outputIndex;\n"
" outputIndexReal=outputIndex;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n"
" int outputIndexReal=outputIndex;\n"
" outputIndexReal=outputIndex;\n"
"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n"
" int out_w=outputIndex % dst_width; outputIndex /= dst_width;\n"
" int out_h=outputIndex % dst_height; outputIndex /= dst_height;\n"
" int out_c=outputIndex % dst_channel;\n"
" int out_b=outputIndex/dst_channel;\n"
" int outputIndexReal=(((out_b+(out_c/4)*dst_batch)*dst_height+out_h)*dst_width+out_w)*4+(out_c % 4);\n"
" outputIndexReal=(((out_b+(out_c/4)*dst_batch)*dst_height+out_h)*dst_width+out_w)*4+(out_c % 4);\n"
"#endif\n"
" output[outputIndexReal]=(OUTPUT_TYPE)input[inputIndexReal];\n"
"}\n"

View File

@ -18,7 +18,7 @@ __private const int global_size_dim0, __private const int global_size_dim1, __pr
if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \
return; \
}
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

View File

@ -12,6 +12,7 @@ const char* reduction =
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
" \n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void reduct_width(GLOBAL_SIZE_3_DIMS\n"
" __read_only image2d_t input,\n"

View File

@ -40,7 +40,7 @@ void ConvLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptr<ConvolutionC
int numAlpha = mResource->mOutputChannel;
mResource->mBlockSize = totalCount / numAlpha;
// set mDequantScale mDequantOffset
int numAlphaPack = ROUND_UP(numAlpha, 16);
int numAlphaPack = ROUND_UP(numAlpha, 4);
int mapSize = mResource->mBlockSize * numAlphaPack * sizeof(int32_t) * 2;
mResource->dequantScaleOffset.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, mapSize));
// transfer data from src in cpu to dst in gpu
@ -242,7 +242,7 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu
if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
}
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision());
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
@ -283,7 +283,7 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu
if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
}
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision());
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision());
uint32_t idx = 0;
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]);
ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]);
@ -347,7 +347,7 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o
if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){
buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
}
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision());
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
@ -391,7 +391,7 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o
if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){
buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT");
}
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision());
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision());
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
@ -445,7 +445,7 @@ void ConvLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * output)
if(inputChannels % 4 != 0){
buildOption.emplace("-DINPUT_CHANNEL_LEAVE");
}
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm", kernelname, buildOption, mOpenCLBackend->getPrecision());
unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_int", kernelname, buildOption, mOpenCLBackend->getPrecision());
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel));
mGlobalWorkSize = {static_cast<uint32_t>(global_x), static_cast<uint32_t>(global_y)};
// MNN_PRINT("Kernel is %d.\n", min_index);
@ -516,13 +516,7 @@ ConvLowMemoryExecution::ConvLowMemoryExecution(const std::vector<Tensor *> &inpu
} else if (conv2dCommonParams->relu6()) {
mResource->mBuildOptions.emplace("-DRELU6");
}
if (mNumQuantBit == 8) {
// int8 case
mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8");
} else if (mNumQuantBit == 4){
// int4 case
mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4");
} else {/* More types to be supported. */}
mResource->mBuildOptions.emplace("-DQUANT_BIT=" + std::to_string(mNumQuantBit));
#ifdef LOG_VERBOSE
MNN_PRINT("end ConvExecution init !\n");
#endif

View File

@ -519,7 +519,7 @@ LoopBinaryExecution::LoopBinaryExecution(const LoopParam *loop, const std::strin
: CommonExecution(bn, op) {
mLoop = loop;
mTensors.resize(mLoop->tensorNumber());
mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute);
mBuildOptions.emplace("-DOPERATOR=" + compute);
}
ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
@ -569,7 +569,6 @@ ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector<Tensor *> &input
{
Unit unit;
std::set<std::string> buildOptions = mBuildOptions;
buildOptions.emplace("-DCOMPUTE_CUMSUM");
unit.kernel = runTime->buildKernel("loop", "loop_cumsum", buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]);
uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(unit.kernel));

View File

@ -80,7 +80,7 @@ ErrorCode PoolExecution::onEncode(const std::vector<Tensor *> &inputs, const std
std::set<std::string> buildOptions;
std::string kernelName = "pooling";
auto runtime = mOpenCLBackend->getOpenCLRuntime();
int local_size;
int local_size = 1;
if (mPoolParams->isGlobal()) {
std::vector<int> inputShape = tensorShapeFormat(inputs[0]);
@ -90,8 +90,8 @@ ErrorCode PoolExecution::onEncode(const std::vector<Tensor *> &inputs, const std
kernelName = "global_pooling";
auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize);
local_size = getLocalSize(inputShape.at(1) * inputShape.at(2), MaxLocalSize);
buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size));
}
buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size));
if (mPadType == PoolPadType_SAME) {
int padNeededHeight = std::max(0, (output->height() - 1) * mStrides[0] + mKernels[0] - input->height());

View File

@ -29,18 +29,12 @@ table GemmInfo {
paramInfo:[uint];
}
table PreParamInfo{
preParamName:string;
preParamData:uint;
}
table BackendInfo{
mnnVersion:string;
deviceName:string;
programs:[Shader];
tunings:[Autotuning];
gemm:[GemmInfo];
preParam:[PreParamInfo];
}
table Cache {

View File

@ -0,0 +1,141 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# Read/write accessors for one auto-tuning record in the OpenCL cache:
# a kernel identification key plus the tuned global/local work sizes and
# the measured execution cost. Generated by flatc from CLCache.fbs.
class Autotuning(object):
    """Read-side wrapper over an `Autotuning` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Autotuning()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsAutotuning(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # Autotuning
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # Autotuning: kernel key string (field 0, vtable offset 4); None if absent.
    def Key(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # Autotuning: tuned global work size, uint32 vector (field 1, offset 6).
    # NOTE: "Gloabl" typo comes from the schema field name and must be kept
    # to match the generated C++ side.
    def GloablSize(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    # Autotuning
    def GloablSizeAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0
    # Autotuning
    def GloablSizeLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # Autotuning
    def GloablSizeIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0
    # Autotuning: tuned local work size, uint32 vector (field 2, offset 8).
    def LocalSize(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    # Autotuning
    def LocalSizeAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0
    # Autotuning
    def LocalSizeLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # Autotuning
    def LocalSizeIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0
    # Autotuning: measured cost, uint32 scalar (field 3, offset 10); 0 default.
    def TimeCost(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
        return 0

# --- Write-side builder helpers (module level). Each `XxxAddField` writes one
# --- vtable slot; the plain `AddField`/`Start`/`End` aliases are the newer
# --- flatc naming and simply forward to the prefixed versions.
def AutotuningStart(builder):
    builder.StartObject(4)
def Start(builder):
    AutotuningStart(builder)
def AutotuningAddKey(builder, key):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0)
def AddKey(builder, key):
    AutotuningAddKey(builder, key)
def AutotuningAddGloablSize(builder, gloablSize):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(gloablSize), 0)
def AddGloablSize(builder, gloablSize):
    AutotuningAddGloablSize(builder, gloablSize)
def AutotuningStartGloablSizeVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartGloablSizeVector(builder, numElems):
    return AutotuningStartGloablSizeVector(builder, numElems)
def AutotuningAddLocalSize(builder, localSize):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(localSize), 0)
def AddLocalSize(builder, localSize):
    AutotuningAddLocalSize(builder, localSize)
def AutotuningStartLocalSizeVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartLocalSizeVector(builder, numElems):
    return AutotuningStartLocalSizeVector(builder, numElems)
def AutotuningAddTimeCost(builder, timeCost):
    builder.PrependUint32Slot(3, timeCost, 0)
def AddTimeCost(builder, timeCost):
    AutotuningAddTimeCost(builder, timeCost)
def AutotuningEnd(builder):
    return builder.EndObject()
def End(builder):
    return AutotuningEnd(builder)

View File

@ -0,0 +1,174 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# Per-backend cache record: MNN version and device name identifying the
# cache, plus the compiled program binaries, kernel tunings, and GEMM
# tunings saved for that backend. Generated by flatc from CLCache.fbs.
class BackendInfo(object):
    """Read-side wrapper over a `BackendInfo` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BackendInfo()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsBackendInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # BackendInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # BackendInfo: MNN version string (field 0, vtable offset 4).
    def MnnVersion(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # BackendInfo: OpenCL device name string (field 1, offset 6).
    def DeviceName(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # BackendInfo: compiled shader programs, vector of Shader tables
    # (field 2, offset 8); each element is an indirect table reference.
    def Programs(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.Shader import Shader
            obj = Shader()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # BackendInfo
    def ProgramsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # BackendInfo
    def ProgramsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0
    # BackendInfo: kernel auto-tuning records, vector of Autotuning tables
    # (field 3, offset 10).
    def Tunings(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.Autotuning import Autotuning
            obj = Autotuning()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # BackendInfo
    def TuningsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # BackendInfo
    def TuningsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        return o == 0
    # BackendInfo: GEMM tuning records, vector of GemmInfo tables
    # (field 4, offset 12).
    def Gemm(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.GemmInfo import GemmInfo
            obj = GemmInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # BackendInfo
    def GemmLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # BackendInfo
    def GemmIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        return o == 0

# --- Write-side builder helpers (module level). Each `BackendInfoAddField`
# --- writes one vtable slot; the unprefixed aliases forward to them.
# --- NOTE: StartObject(5) reserves 5 slots although only 5 fields exist
# --- after the schema dropped `preParam` (see the matching C++ header).
def BackendInfoStart(builder):
    builder.StartObject(5)
def Start(builder):
    BackendInfoStart(builder)
def BackendInfoAddMnnVersion(builder, mnnVersion):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(mnnVersion), 0)
def AddMnnVersion(builder, mnnVersion):
    BackendInfoAddMnnVersion(builder, mnnVersion)
def BackendInfoAddDeviceName(builder, deviceName):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(deviceName), 0)
def AddDeviceName(builder, deviceName):
    BackendInfoAddDeviceName(builder, deviceName)
def BackendInfoAddPrograms(builder, programs):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(programs), 0)
def AddPrograms(builder, programs):
    BackendInfoAddPrograms(builder, programs)
def BackendInfoStartProgramsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartProgramsVector(builder, numElems):
    return BackendInfoStartProgramsVector(builder, numElems)
def BackendInfoAddTunings(builder, tunings):
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(tunings), 0)
def AddTunings(builder, tunings):
    BackendInfoAddTunings(builder, tunings)
def BackendInfoStartTuningsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartTuningsVector(builder, numElems):
    return BackendInfoStartTuningsVector(builder, numElems)
def BackendInfoAddGemm(builder, gemm):
    builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(gemm), 0)
def AddGemm(builder, gemm):
    BackendInfoAddGemm(builder, gemm)
def BackendInfoStartGemmVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartGemmVector(builder, numElems):
    return BackendInfoStartGemmVector(builder, numElems)
def BackendInfoEnd(builder):
    return builder.EndObject()
def End(builder):
    return BackendInfoEnd(builder)

View File

@ -0,0 +1,111 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# Root table of the OpenCL cache file: a list of per-backend records
# (programs + tunings) and a list of tuned op descriptions. Generated by
# flatc from CLCache.fbs.
class Cache(object):
    """Read-side wrapper over the root `Cache` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Cache()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsCache(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # Cache
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # Cache: per-backend cache entries, vector of BackendInfo tables
    # (field 0, vtable offset 4); each element is an indirect reference.
    def Backends(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.BackendInfo import BackendInfo
            obj = BackendInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # Cache
    def BackendsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # Cache
    def BackendsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0
    # Cache: already-tuned op descriptions, vector of OpInfo tables
    # (field 1, offset 6).
    def Tuned(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.OpInfo import OpInfo
            obj = OpInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # Cache
    def TunedLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # Cache
    def TunedIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

# --- Write-side builder helpers (module level); unprefixed aliases forward
# --- to the `CacheXxx` versions.
def CacheStart(builder):
    builder.StartObject(2)
def Start(builder):
    CacheStart(builder)
def CacheAddBackends(builder, backends):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(backends), 0)
def AddBackends(builder, backends):
    CacheAddBackends(builder, backends)
def CacheStartBackendsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartBackendsVector(builder, numElems):
    return CacheStartBackendsVector(builder, numElems)
def CacheAddTuned(builder, tuned):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(tuned), 0)
def AddTuned(builder, tuned):
    CacheAddTuned(builder, tuned)
def CacheStartTunedVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartTunedVector(builder, numElems):
    return CacheStartTunedVector(builder, numElems)
def CacheEnd(builder):
    return builder.EndObject()
def End(builder):
    return CacheEnd(builder)

View File

@ -0,0 +1,115 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# One GEMM tuning record: the problem size and the tuned kernel
# parameters, both stored as raw uint32 vectors. Generated by flatc from
# CLCache.fbs.
class GemmInfo(object):
    """Read-side wrapper over a `GemmInfo` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = GemmInfo()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsGemmInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # GemmInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # GemmInfo: GEMM problem dimensions, uint32 vector (field 0, offset 4).
    def GemmSize(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    # GemmInfo
    def GemmSizeAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0
    # GemmInfo
    def GemmSizeLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # GemmInfo
    def GemmSizeIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0
    # GemmInfo: tuned kernel parameters, uint32 vector (field 1, offset 6).
    def ParamInfo(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    # GemmInfo
    def ParamInfoAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0
    # GemmInfo
    def ParamInfoLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # GemmInfo
    def ParamInfoIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

# --- Write-side builder helpers (module level); unprefixed aliases forward
# --- to the `GemmInfoXxx` versions.
def GemmInfoStart(builder):
    builder.StartObject(2)
def Start(builder):
    GemmInfoStart(builder)
def GemmInfoAddGemmSize(builder, gemmSize):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(gemmSize), 0)
def AddGemmSize(builder, gemmSize):
    GemmInfoAddGemmSize(builder, gemmSize)
def GemmInfoStartGemmSizeVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartGemmSizeVector(builder, numElems):
    return GemmInfoStartGemmSizeVector(builder, numElems)
def GemmInfoAddParamInfo(builder, paramInfo):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(paramInfo), 0)
def AddParamInfo(builder, paramInfo):
    GemmInfoAddParamInfo(builder, paramInfo)
def GemmInfoStartParamInfoVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartParamInfoVector(builder, numElems):
    return GemmInfoStartParamInfoVector(builder, numElems)
def GemmInfoEnd(builder):
    return builder.EndObject()
def End(builder):
    return GemmInfoEnd(builder)

View File

@ -0,0 +1,137 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# Description of one tuned op: its name, numeric type code, and the
# shapes of its input/output tensors. Generated by flatc from CLCache.fbs.
class OpInfo(object):
    """Read-side wrapper over an `OpInfo` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = OpInfo()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsOpInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # OpInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # OpInfo: op name string (field 0, vtable offset 4); None if absent.
    def Name(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # OpInfo: op type code, int32 scalar (field 1, offset 6); 0 default.
    def Type(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
    # OpInfo: input tensor descriptions, vector of TensorInfo tables
    # (field 2, offset 8); each element is an indirect table reference.
    def Inputs(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.TensorInfo import TensorInfo
            obj = TensorInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # OpInfo
    def InputsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # OpInfo
    def InputsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0
    # OpInfo: output tensor descriptions, vector of TensorInfo tables
    # (field 3, offset 10).
    def Outputs(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from CLCache.TensorInfo import TensorInfo
            obj = TensorInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
    # OpInfo
    def OutputsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # OpInfo
    def OutputsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        return o == 0

# --- Write-side builder helpers (module level); unprefixed aliases forward
# --- to the `OpInfoXxx` versions.
def OpInfoStart(builder):
    builder.StartObject(4)
def Start(builder):
    OpInfoStart(builder)
def OpInfoAddName(builder, name):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def AddName(builder, name):
    OpInfoAddName(builder, name)
def OpInfoAddType(builder, type):
    builder.PrependInt32Slot(1, type, 0)
def AddType(builder, type):
    OpInfoAddType(builder, type)
def OpInfoAddInputs(builder, inputs):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
def AddInputs(builder, inputs):
    OpInfoAddInputs(builder, inputs)
def OpInfoStartInputsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartInputsVector(builder, numElems):
    return OpInfoStartInputsVector(builder, numElems)
def OpInfoAddOutputs(builder, outputs):
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
def AddOutputs(builder, outputs):
    OpInfoAddOutputs(builder, outputs)
def OpInfoStartOutputsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartOutputsVector(builder, numElems):
    return OpInfoStartOutputsVector(builder, numElems)
def OpInfoEnd(builder):
    return builder.EndObject()
def End(builder):
    return OpInfoEnd(builder)

View File

@ -0,0 +1,115 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# One cached compiled OpenCL program: the compiled binary bytes plus the
# program/kernel names and build options that identify it. Generated by
# flatc from CLCache.fbs.
class Shader(object):
    """Read-side wrapper over a `Shader` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Shader()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsShader(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # Shader
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # Shader: compiled program binary, int8 vector (field 0, vtable offset 4).
    def Buffer(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
        return 0
    # Shader
    def BufferAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
        return 0
    # Shader
    def BufferLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # Shader
    def BufferIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0
    # Shader: program name string (field 1, offset 6); None if absent.
    def Program(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # Shader: kernel name string (field 2, offset 8); None if absent.
    def Kernel(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
    # Shader: build options string (field 3, offset 10); None if absent.
    def BuildInfo(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

# --- Write-side builder helpers (module level); unprefixed aliases forward
# --- to the `ShaderXxx` versions.
def ShaderStart(builder):
    builder.StartObject(4)
def Start(builder):
    ShaderStart(builder)
def ShaderAddBuffer(builder, buffer):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(buffer), 0)
def AddBuffer(builder, buffer):
    ShaderAddBuffer(builder, buffer)
def ShaderStartBufferVector(builder, numElems):
    return builder.StartVector(1, numElems, 1)
def StartBufferVector(builder, numElems):
    return ShaderStartBufferVector(builder, numElems)
def ShaderAddProgram(builder, program):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(program), 0)
def AddProgram(builder, program):
    ShaderAddProgram(builder, program)
def ShaderAddKernel(builder, kernel):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0)
def AddKernel(builder, kernel):
    ShaderAddKernel(builder, kernel)
def ShaderAddBuildInfo(builder, buildInfo):
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(buildInfo), 0)
def AddBuildInfo(builder, buildInfo):
    ShaderAddBuildInfo(builder, buildInfo)
def ShaderEnd(builder):
    return builder.EndObject()
def End(builder):
    return ShaderEnd(builder)

View File

@ -0,0 +1,76 @@
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: CLCache
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
# Shape description of one tensor, stored as an int32 vector. Generated
# by flatc from CLCache.fbs.
class TensorInfo(object):
    """Read-side wrapper over a `TensorInfo` FlatBuffers table."""
    __slots__ = ['_tab']
    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Resolve the root table position inside `buf` and wrap it.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = TensorInfo()
        x.Init(buf, n + offset)
        return x
    @classmethod
    def GetRootAsTensorInfo(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    # TensorInfo
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)
    # TensorInfo: tensor dimensions, int32 vector (field 0, vtable offset 4).
    def Shape(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0
    # TensorInfo
    def ShapeAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
        return 0
    # TensorInfo
    def ShapeLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0
    # TensorInfo
    def ShapeIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0

# --- Write-side builder helpers (module level); unprefixed aliases forward
# --- to the `TensorInfoXxx` versions.
def TensorInfoStart(builder):
    builder.StartObject(1)
def Start(builder):
    TensorInfoStart(builder)
def TensorInfoAddShape(builder, shape):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
def AddShape(builder, shape):
    TensorInfoAddShape(builder, shape)
def TensorInfoStartShapeVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)
def StartShapeVector(builder, numElems):
    return TensorInfoStartShapeVector(builder, numElems)
def TensorInfoEnd(builder):
    return builder.EndObject()
def End(builder):
    return TensorInfoEnd(builder)

View File

@ -23,9 +23,6 @@ struct AutotuningT;
struct GemmInfo;
struct GemmInfoT;
struct PreParamInfo;
struct PreParamInfoT;
struct BackendInfo;
struct BackendInfoT;
@ -42,8 +39,6 @@ inline const flatbuffers::TypeTable *AutotuningTypeTable();
inline const flatbuffers::TypeTable *GemmInfoTypeTable();
inline const flatbuffers::TypeTable *PreParamInfoTypeTable();
inline const flatbuffers::TypeTable *BackendInfoTypeTable();
inline const flatbuffers::TypeTable *CacheTypeTable();
@ -430,71 +425,6 @@ inline flatbuffers::Offset<GemmInfo> CreateGemmInfo(
flatbuffers::Offset<GemmInfo> CreateGemmInfo(flatbuffers::FlatBufferBuilder &_fbb, const GemmInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct PreParamInfoT : public flatbuffers::NativeTable {
typedef PreParamInfo TableType;
std::string preParamName;
uint32_t preParamData;
PreParamInfoT()
: preParamData(0) {
}
};
struct PreParamInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef PreParamInfoT NativeTableType;
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
return PreParamInfoTypeTable();
}
const flatbuffers::String *preParamName() const {
return GetPointer<const flatbuffers::String *>(4);
}
uint32_t preParamData() const {
return GetField<uint32_t>(6, 0);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, 4) &&
verifier.VerifyString(preParamName()) &&
VerifyField<uint32_t>(verifier, 6) &&
verifier.EndTable();
}
PreParamInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(PreParamInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<PreParamInfo> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct PreParamInfoBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_preParamName(flatbuffers::Offset<flatbuffers::String> preParamName) {
fbb_.AddOffset(4, preParamName);
}
void add_preParamData(uint32_t preParamData) {
fbb_.AddElement<uint32_t>(6, preParamData, 0);
}
explicit PreParamInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
PreParamInfoBuilder &operator=(const PreParamInfoBuilder &);
flatbuffers::Offset<PreParamInfo> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<PreParamInfo>(end);
return o;
}
};
inline flatbuffers::Offset<PreParamInfo> CreatePreParamInfo(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::String> preParamName = 0,
uint32_t preParamData = 0) {
PreParamInfoBuilder builder_(_fbb);
builder_.add_preParamData(preParamData);
builder_.add_preParamName(preParamName);
return builder_.Finish();
}
flatbuffers::Offset<PreParamInfo> CreatePreParamInfo(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct BackendInfoT : public flatbuffers::NativeTable {
typedef BackendInfo TableType;
std::string mnnVersion;
@ -502,7 +432,6 @@ struct BackendInfoT : public flatbuffers::NativeTable {
std::vector<std::unique_ptr<ShaderT>> programs;
std::vector<std::unique_ptr<AutotuningT>> tunings;
std::vector<std::unique_ptr<GemmInfoT>> gemm;
std::vector<std::unique_ptr<PreParamInfoT>> preParam;
BackendInfoT() {
}
};
@ -527,9 +456,6 @@ struct BackendInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<flatbuffers::Offset<GemmInfo>> *gemm() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<GemmInfo>> *>(12);
}
const flatbuffers::Vector<flatbuffers::Offset<PreParamInfo>> *preParam() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<PreParamInfo>> *>(14);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, 4) &&
@ -545,9 +471,6 @@ struct BackendInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, 12) &&
verifier.VerifyVector(gemm()) &&
verifier.VerifyVectorOfTables(gemm()) &&
VerifyOffset(verifier, 14) &&
verifier.VerifyVector(preParam()) &&
verifier.VerifyVectorOfTables(preParam()) &&
verifier.EndTable();
}
BackendInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -573,9 +496,6 @@ struct BackendInfoBuilder {
void add_gemm(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<GemmInfo>>> gemm) {
fbb_.AddOffset(12, gemm);
}
void add_preParam(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<PreParamInfo>>> preParam) {
fbb_.AddOffset(14, preParam);
}
explicit BackendInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -594,10 +514,8 @@ inline flatbuffers::Offset<BackendInfo> CreateBackendInfo(
flatbuffers::Offset<flatbuffers::String> deviceName = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Shader>>> programs = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Autotuning>>> tunings = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<GemmInfo>>> gemm = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<PreParamInfo>>> preParam = 0) {
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<GemmInfo>>> gemm = 0) {
BackendInfoBuilder builder_(_fbb);
builder_.add_preParam(preParam);
builder_.add_gemm(gemm);
builder_.add_tunings(tunings);
builder_.add_programs(programs);
@ -835,35 +753,6 @@ inline flatbuffers::Offset<GemmInfo> CreateGemmInfo(flatbuffers::FlatBufferBuild
_paramInfo);
}
inline PreParamInfoT *PreParamInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new PreParamInfoT();
UnPackTo(_o, _resolver);
return _o;
}
inline void PreParamInfo::UnPackTo(PreParamInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = preParamName(); if (_e) _o->preParamName = _e->str(); };
{ auto _e = preParamData(); _o->preParamData = _e; };
}
inline flatbuffers::Offset<PreParamInfo> PreParamInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreatePreParamInfo(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<PreParamInfo> CreatePreParamInfo(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PreParamInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _preParamName = _o->preParamName.empty() ? 0 : _fbb.CreateString(_o->preParamName);
auto _preParamData = _o->preParamData;
return CLCache::CreatePreParamInfo(
_fbb,
_preParamName,
_preParamData);
}
inline BackendInfoT *BackendInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BackendInfoT();
UnPackTo(_o, _resolver);
@ -878,7 +767,6 @@ inline void BackendInfo::UnPackTo(BackendInfoT *_o, const flatbuffers::resolver_
{ auto _e = programs(); if (_e) { _o->programs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->programs[_i] = std::unique_ptr<ShaderT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = tunings(); if (_e) { _o->tunings.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tunings[_i] = std::unique_ptr<AutotuningT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = gemm(); if (_e) { _o->gemm.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->gemm[_i] = std::unique_ptr<GemmInfoT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = preParam(); if (_e) { _o->preParam.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->preParam[_i] = std::unique_ptr<PreParamInfoT>(_e->Get(_i)->UnPack(_resolver)); } } };
}
inline flatbuffers::Offset<BackendInfo> BackendInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BackendInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -894,15 +782,13 @@ inline flatbuffers::Offset<BackendInfo> CreateBackendInfo(flatbuffers::FlatBuffe
auto _programs = _o->programs.size() ? _fbb.CreateVector<flatbuffers::Offset<Shader>> (_o->programs.size(), [](size_t i, _VectorArgs *__va) { return CreateShader(*__va->__fbb, __va->__o->programs[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _tunings = _o->tunings.size() ? _fbb.CreateVector<flatbuffers::Offset<Autotuning>> (_o->tunings.size(), [](size_t i, _VectorArgs *__va) { return CreateAutotuning(*__va->__fbb, __va->__o->tunings[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _gemm = _o->gemm.size() ? _fbb.CreateVector<flatbuffers::Offset<GemmInfo>> (_o->gemm.size(), [](size_t i, _VectorArgs *__va) { return CreateGemmInfo(*__va->__fbb, __va->__o->gemm[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _preParam = _o->preParam.size() ? _fbb.CreateVector<flatbuffers::Offset<PreParamInfo>> (_o->preParam.size(), [](size_t i, _VectorArgs *__va) { return CreatePreParamInfo(*__va->__fbb, __va->__o->preParam[i].get(), __va->__rehasher); }, &_va ) : 0;
return CLCache::CreateBackendInfo(
_fbb,
_mnnVersion,
_deviceName,
_programs,
_tunings,
_gemm,
_preParam);
_gemm);
}
inline CacheT *Cache::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -1022,46 +908,28 @@ inline const flatbuffers::TypeTable *GemmInfoTypeTable() {
return &tt;
}
// FlatBuffers-generated reflection table for PreParamInfo: two fields,
// a string (preParamName) and a uint (preParamData).
inline const flatbuffers::TypeTable *PreParamInfoTypeTable() {
  static const flatbuffers::TypeCode type_codes[] = {
    { flatbuffers::ET_STRING, 0, -1 },
    { flatbuffers::ET_UINT, 0, -1 }
  };
  static const char * const names[] = {
    "preParamName",
    "preParamData"
  };
  static const flatbuffers::TypeTable tt = {
    flatbuffers::ST_TABLE, 2, type_codes, nullptr, nullptr, names
  };
  return &tt;
}
inline const flatbuffers::TypeTable *BackendInfoTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_STRING, 0, -1 },
{ flatbuffers::ET_STRING, 0, -1 },
{ flatbuffers::ET_SEQUENCE, 1, 0 },
{ flatbuffers::ET_SEQUENCE, 1, 1 },
{ flatbuffers::ET_SEQUENCE, 1, 2 },
{ flatbuffers::ET_SEQUENCE, 1, 3 }
{ flatbuffers::ET_SEQUENCE, 1, 2 }
};
static const flatbuffers::TypeFunction type_refs[] = {
ShaderTypeTable,
AutotuningTypeTable,
GemmInfoTypeTable,
PreParamInfoTypeTable
GemmInfoTypeTable
};
static const char * const names[] = {
"mnnVersion",
"deviceName",
"programs",
"tunings",
"gemm",
"preParam"
"gemm"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 6, type_codes, type_refs, nullptr, names
flatbuffers::ST_TABLE, 5, type_codes, type_refs, nullptr, names
};
return &tt;
}

View File

@ -0,0 +1,80 @@
import flatbuffers
from CLCache.Cache import Cache
from CLCache.BackendInfo import BackendInfo
def generate_cpp_header(buffer):
    """Render the tuning records of a CLCache blob as a C++ OpenCLTuneMap literal.

    buffer: raw bytes of a serialized CLCache::Cache flatbuffer.
    Returns C++ source text declaring a nested std::map keyed by device name,
    then by tuning key, holding (global_size/param, local_size/param) pairs.
    """
    cache = Cache.GetRootAs(buffer, 0)
    device_to_tunes = {}
    # Collect every BackendInfo entry, grouped by its device name.
    for backend_idx in range(cache.BackendsLength()):
        backend = cache.Backends(backend_idx)
        name = backend.DeviceName().decode("utf-8")
        tune_entries = {}
        # Kernel autotuning records: key -> list of (global_size, local_size).
        for t in range(backend.TuningsLength()):
            record = backend.Tunings(t)
            tune_key = record.Key().decode("utf-8")
            gws = [record.GloablSize(k) for k in range(record.GloablSizeLength())]
            lws = [record.LocalSize(k) for k in range(record.LocalSizeLength())]
            tune_entries.setdefault(tune_key, []).append((gws, lws))
        # GEMM records are all folded under one fixed key.
        for g in range(backend.GemmLength()):
            record = backend.Gemm(g)
            sizes = [record.GemmSize(k) for k in range(record.GemmSizeLength())]
            params = [record.ParamInfo(k) for k in range(record.ParamInfoLength())]
            tune_entries.setdefault('Xgemm_tune', []).append((sizes, params))
        device_to_tunes[name] = tune_entries
    # Emit the nested C++ initializer.
    pieces = ["const std::map<std::string, std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>>>> OpenCLTuneMap = {\n"]
    for device, tuning_map in device_to_tunes.items():
        pieces.append(f' {{"{device}", {{\n')
        for tune_key, size_pairs in tuning_map.items():
            pieces.append(f' {{"{tune_key}", {{\n')
            for sizes in size_pairs:
                pieces.append(f' {{{{ {", ".join(map(str, sizes[0]))} }}, {{ {", ".join(map(str, sizes[1]))} }}}},\n')
            pieces.append(" }},\n")
        pieces.append(" }},\n")
    pieces.append("};\n")
    return "".join(pieces)
if __name__ == '__main__':
    import sys
    if len(sys.argv) != 2:
        # Fixed: the example line previously named merge_cache.py, which is a
        # different tool; usage and example now agree.
        print("Usage: python import_cache.py <cache_file>")
        print("Example: python import_cache.py mnn_cachefile.bin")
        sys.exit(1)
    with open(sys.argv[1], 'rb') as f:
        buffer = f.read()
    cpp_header_code = generate_cpp_header(buffer)
    # Persist the generated table into a self-contained C++ header.
    with open('OpenCLTuneMap.hpp', 'w') as header_file:
        header_file.write("#include <map>\n#include <string>\n#include <vector>\n\nnamespace MNN { \n")
        header_file.write(cpp_header_code)
        header_file.write("\n}\n")
    print("C++ header file generated.")

View File

@ -0,0 +1,160 @@
import flatbuffers
from CLCache import Cache, BackendInfo, Autotuning, GemmInfo
def load_backend_infos(file_path):
    """Read a serialized CLCache file and return its BackendInfo tables as a list."""
    with open(file_path, 'rb') as f:
        raw = bytearray(f.read())
    cache = Cache.Cache.GetRootAs(raw, 0)
    return [cache.Backends(i) for i in range(cache.BackendsLength())]
def load_tune_infos(backends):
    """Index tuning records by (mnnVersion, deviceName).

    For each backend, builds a dict mapping (key, tuple(global_size)) to
    (local_size, time_cost). GEMM records are stored under the fixed key
    'Xgemm_tune' with a cost of 0.
    """
    indexed = {}
    for backend in backends:
        version = backend.MnnVersion()
        device = backend.DeviceName()
        records = {}
        # Regular kernel autotuning entries.
        for idx in range(backend.TuningsLength()):
            entry = backend.Tunings(idx)
            gws = tuple(entry.GloablSize(j) for j in range(entry.GloablSizeLength()))
            lws = [entry.LocalSize(j) for j in range(entry.LocalSizeLength())]
            records[(entry.Key(), gws)] = (lws, entry.TimeCost())
        # GEMM tuning entries.
        for idx in range(backend.GemmLength()):
            entry = backend.Gemm(idx)
            sizes = tuple(entry.GemmSize(j) for j in range(entry.GemmSizeLength()))
            params = [entry.ParamInfo(j) for j in range(entry.ParamInfoLength())]
            records[('Xgemm_tune', sizes)] = (params, 0)
        indexed[(version, device)] = records
    return indexed
def create_backend_info(new_backends, original_backends):
    """Merge tuning records, preferring entries already in the original set.

    Records from new_backends are only added where the original set has no
    entry for the same (version, device) and tuning key.
    """
    merged = load_tune_infos(original_backends)
    incoming = load_tune_infos(new_backends)
    for ver_dev, new_records in incoming.items():
        existing = merged.get(ver_dev)
        if existing is None:
            # Device/version unseen so far: take the whole record set.
            merged[ver_dev] = new_records
        else:
            # Known device/version: fill in only the missing tuning keys.
            for record_key, record_val in new_records.items():
                existing.setdefault(record_key, record_val)
    return merged
def build_cache(nested_dict):
    """Serialize the merged tuning dictionary back into a CLCache flatbuffer.

    nested_dict: {(mnn_version, device_name): {(key, sizes): (values, cost)}}
    where key == 'Xgemm_tune' marks a GemmInfo record; any other key is an
    Autotuning record.
    Returns the finished flatbuffer bytes (builder.Output()).
    """
    builder = flatbuffers.Builder()
    # ====================== Build the BackendInfo list ======================
    backend_offsets = []
    for (mnn_ver, device_name), autotune_dict in nested_dict.items():
        mnn_ver_offset = builder.CreateString(mnn_ver)
        device_name_offset = builder.CreateString(device_name)
        tuning_offsets = []
        gemm_offsets = []
        for (key, global_size), (local_size, time_cost) in autotune_dict.items():
            if key == 'Xgemm_tune':
                # gemmSize vector (flatbuffers vectors are filled back-to-front)
                GemmInfo.GemmInfoStartGemmSizeVector(builder, len(global_size))
                for n in reversed(global_size):
                    builder.PrependUint32(n)
                gemm_size_offset = builder.EndVector()
                # paramInfo vector
                GemmInfo.GemmInfoStartParamInfoVector(builder, len(local_size))
                for n in reversed(local_size):
                    builder.PrependUint32(n)
                param_info_offset = builder.EndVector()
                GemmInfo.GemmInfoStart(builder)
                GemmInfo.GemmInfoAddGemmSize(builder, gemm_size_offset)
                GemmInfo.GemmInfoAddParamInfo(builder, param_info_offset)
                gemm_offsets.append(GemmInfo.GemmInfoEnd(builder))
            else:
                key_offset = builder.CreateString(key)
                # globalSize vector (back-to-front)
                Autotuning.AutotuningStartGloablSizeVector(builder, len(global_size))
                for n in reversed(global_size):
                    builder.PrependUint32(n)
                global_size_offset = builder.EndVector()
                # localSize vector (back-to-front)
                Autotuning.AutotuningStartLocalSizeVector(builder, len(local_size))
                for n in reversed(local_size):
                    builder.PrependUint32(n)
                local_size_offset = builder.EndVector()
                Autotuning.AutotuningStart(builder)
                Autotuning.AutotuningAddKey(builder, key_offset)
                Autotuning.AutotuningAddGloablSize(builder, global_size_offset)
                Autotuning.AutotuningAddLocalSize(builder, local_size_offset)
                # Fixed: timeCost was previously written twice in a row (the
                # real value, then 0); each table field slot must be set once.
                Autotuning.AutotuningAddTimeCost(builder, time_cost)
                tuning_offsets.append(Autotuning.AutotuningEnd(builder))
        # tunings vector
        BackendInfo.BackendInfoStartTuningsVector(builder, len(tuning_offsets))
        for offset in reversed(tuning_offsets):
            builder.PrependUOffsetTRelative(offset)
        tunings_vector = builder.EndVector()
        # gemm vector (fixed: result no longer clobbers the gemm_offsets list)
        BackendInfo.BackendInfoStartGemmVector(builder, len(gemm_offsets))
        for offset in reversed(gemm_offsets):
            builder.PrependUOffsetTRelative(offset)
        gemm_vector = builder.EndVector()
        # assemble BackendInfo
        BackendInfo.BackendInfoStart(builder)
        BackendInfo.BackendInfoAddMnnVersion(builder, mnn_ver_offset)
        BackendInfo.BackendInfoAddDeviceName(builder, device_name_offset)
        BackendInfo.BackendInfoAddTunings(builder, tunings_vector)
        BackendInfo.BackendInfoAddGemm(builder, gemm_vector)
        backend_offsets.append(BackendInfo.BackendInfoEnd(builder))
    # ====================== Build the root Cache ======================
    Cache.CacheStartBackendsVector(builder, len(backend_offsets))
    for offset in reversed(backend_offsets):
        builder.PrependUOffsetTRelative(offset)
    backends_vector = builder.EndVector()
    Cache.CacheStart(builder)
    Cache.CacheAddBackends(builder, backends_vector)
    builder.Finish(Cache.CacheEnd(builder))
    return builder.Output()
if __name__ == '__main__':
    import sys
    if len(sys.argv) != 4:
        print("Usage: python merge_cache.py <primary_file> <total_file> <output_file>")
        print("Example: python merge_cache.py mnn_cachefile.bin mnn_cachefile_total.bin new_cache.bin")
        sys.exit(1)
    # argv[1] is the authoritative cache; argv[2] only contributes records
    # the primary file does not already have.
    primary_backends = load_backend_infos(sys.argv[1])
    extra_backends = load_backend_infos(sys.argv[2])
    merged_map = create_backend_info(extra_backends, primary_backends)
    with open(sys.argv[3], "wb") as out:
        out.write(build_cache(merged_map))

View File

@ -19,8 +19,6 @@ class DepthToSpaceSizeComputer : public SizeComputer {
MNN_ASSERT(outputs.size() == 1);
MNN_ASSERT(inputs[0]->buffer().dimensions == 4);
// here only implement NHWC
// TODO: implement NC4HW4
const int blockSize = op->main_as_DepthSpaceParam()->blockSize();
MNN_ASSERT(blockSize >= 1);
auto format = TensorUtils::getDescribe(inputs[0])->dimensionFormat;

View File

@ -19,10 +19,8 @@ class SpaceToDepthSizeComputer : public SizeComputer {
MNN_ASSERT(outputs.size() == 1);
MNN_ASSERT(inputs[0]->buffer().dimensions == 4);
// here only implement NHWC
// TODO: implement NC4HW4
const int blockSize = op->main_as_DepthSpaceParam()->blockSize();
MNN_ASSERT(blockSize > 1);
MNN_ASSERT(blockSize >= 1);
auto& ib = inputs[0]->buffer();
auto& ob = outputs[0]->buffer();

View File

@ -17,6 +17,10 @@ IF(CMAKE_SYSTEM_NAME MATCHES "^Android")
target_link_libraries(GpuInterTest.out android)
list(APPEND MNN_CPP_TOOLS GpuInterTest.out)
ENDIF()
add_executable(OpenCLProgramBuildTest.out ${CMAKE_CURRENT_LIST_DIR}/OpenCLProgramBuildTest.cpp )
list(APPEND MNN_CPP_TOOLS OpenCLProgramBuildTest.out)
add_executable(SequenceModuleTest.out ${CMAKE_CURRENT_LIST_DIR}/SequenceModuleTest.cpp)
list(APPEND MNN_CPP_TOOLS SequenceModuleTest.out)

View File

@ -0,0 +1,244 @@
//
// OpenCLProgramBuildTest.cpp
// MNN
//
// Created by MNN on 2025/5/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <fstream>
#include <string>
#include <vector>
#include "CL/cl.h"
#ifdef _WIN32
#include <windows.h>
#include <libloaderapi.h>
#else
#include <dlfcn.h>
#endif
using clGetPlatformIDsFunc = cl_int (CL_API_CALL *)(cl_uint, cl_platform_id *, cl_uint *);
using clBuildProgramFunc = cl_int (CL_API_CALL *)(cl_program, cl_uint, const cl_device_id *, const char *, void (CL_CALLBACK *pfn_notify)(cl_program, void *), void *);
using clCreateProgramWithSourceFunc = cl_program (CL_API_CALL *)(cl_context, cl_uint, const char **, const size_t *, cl_int *);
using clGetProgramBuildInfoFunc = cl_int (CL_API_CALL *)(cl_program, cl_device_id, cl_program_build_info, size_t, void *, size_t *);
using clCreateContextFunc = cl_context (CL_API_CALL *)(const cl_context_properties *, cl_uint, const cl_device_id *,
void(CL_CALLBACK *)( // NOLINT(readability/casting)
const char *, const void *, size_t, void *),
void *, cl_int *);
using clGetDeviceIDsFunc = cl_int (CL_API_CALL *)(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *);
using clGetDeviceInfoFunc = cl_int (CL_API_CALL *)(cl_device_id, cl_device_info, size_t, void *, size_t *);
using clReleaseProgramFunc = cl_int (CL_API_CALL *)(cl_program program);
using clReleaseContextFunc = cl_int (CL_API_CALL *)(cl_context);
using clReleaseDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id);
class OpenCLProgramTest {
public:
    // Locates an OpenCL runtime library for the current platform, loads the
    // entry points used below and creates a platform/device/context for
    // program-build testing. mIsSupportAvailable stays false when no runtime
    // could be loaded.
    OpenCLProgramTest(){
        static const std::vector<std::string> gOpencl_library_paths = {
#if defined(__APPLE__) || defined(__MACOSX)
            "libOpenCL.so", "/System/Library/Frameworks/OpenCL.framework/OpenCL"
#elif defined(__OHOS__)
            "/vendor/lib64/chipsetsdk/libhvgr_v200.so",
            "/vendor/lib64/chipsetsdk/libGLES_mali.so",
            "/system/lib64/libGLES_mali.so",
            "libGLES_mali.so",
            "/vendor/lib64/chipsetsdk/libEGI_imp1.so",
#elif defined(__ANDROID__)
            "libOpenCL.so",
            "libGLES_mali.so",
            "libmali.so",
            "libOpenCL-pixel.so",
#if defined(__aarch64__)
            // Qualcomm Adreno
            "/system/vendor/lib64/libOpenCL.so",
            "/system/lib64/libOpenCL.so",
            // Mali
            "/system/vendor/lib64/egl/libGLES_mali.so",
            "/system/lib64/egl/libGLES_mali.so",
#else
            // Qualcomm Adreno
            "/system/vendor/lib/libOpenCL.so", "/system/lib/libOpenCL.so",
            // Mali
            "/system/vendor/lib/egl/libGLES_mali.so", "/system/lib/egl/libGLES_mali.so",
            // other
            "/system/vendor/lib/libPVROCL.so", "/data/data/org.pocl.libs/files/lib/libpocl.so"
#endif
#elif defined(__linux__)
            "/usr/lib/libOpenCL.so",
            "/usr/local/lib/libOpenCL.so",
            "/usr/local/lib/libpocl.so",
            "/usr/lib64/libOpenCL.so",
            "/usr/lib32/libOpenCL.so",
            "libOpenCL.so"
#elif defined(_WIN64)
            "C:/Windows/System32/OpenCL.dll",
            "C:/Windows/SysWOW64/OpenCL.dll"
#elif defined(_WIN32)
            "C:/Windows/SysWOW64/OpenCL.dll",
            "C:/Windows/System32/OpenCL.dll"
#endif
        };
        for (const auto &opencl_lib : gOpencl_library_paths) {
            if (LoadLibraryFromPath(opencl_lib)) {
                mIsSupportAvailable = true;
                // Fixed: stop at the first runtime that loads; continuing would
                // overwrite handle_ and leak the previously opened library.
                break;
            }
        }
        if (mIsSupportAvailable) {
            cl_int err;
            err = clGetPlatformIDs(1, &platform, NULL);
            if (err != CL_SUCCESS) {
                printf("Failed to get platform ID err = %d\n", err);
                return;
            }
            err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, NULL);
            if (err != CL_SUCCESS) {
                printf("Failed to get device ID err = %d\n", err);
                return;
            }
            context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
            if (!context || err != CL_SUCCESS) {
                printf("Failed to create context err = %d\n", err);
                return;
            }
        }
    }
    // Opens a shared library and resolves the needed OpenCL symbols.
    // Returns false when the library itself cannot be opened; individual
    // symbols may still be null afterwards (checked before use in ~dtor).
    bool LoadLibraryFromPath(const std::string &library_path){
#if defined(_WIN32)
        handle_ = LoadLibraryA(library_path.c_str());
        if (handle_ == nullptr) {
            return false;
        }
#define MNN_LOAD_FUNCTION_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(GetProcAddress(static_cast<HMODULE>(handle_), #func_name));
#else
        handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL);
        if (handle_ == nullptr) {
            return false;
        }
#define MNN_LOAD_FUNCTION_PTR(func_name) func_name = reinterpret_cast<func_name##Func>(dlsym(handle_, #func_name));
#endif
        MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs);
        MNN_LOAD_FUNCTION_PTR(clBuildProgram);
        MNN_LOAD_FUNCTION_PTR(clCreateProgramWithSource);
        MNN_LOAD_FUNCTION_PTR(clGetProgramBuildInfo);
        MNN_LOAD_FUNCTION_PTR(clCreateContext);
        MNN_LOAD_FUNCTION_PTR(clGetDeviceIDs);
        MNN_LOAD_FUNCTION_PTR(clGetDeviceInfo);
        MNN_LOAD_FUNCTION_PTR(clReleaseProgram);
        MNN_LOAD_FUNCTION_PTR(clReleaseContext);
        MNN_LOAD_FUNCTION_PTR(clReleaseDevice);
        return true;
    }
    // Compiles kernel.cl once per option string; returns false on the first
    // failed compilation (after printing the build log).
    bool TestProgram(const std::vector<std::string> options){
        cl_int err;
        FILE* file = fopen("kernel.cl", "r");
        if (!file) {
            printf("Failed to open kernel file: kernel.cl\n");
            return false;
        }
        fseek(file, 0, SEEK_END);
        size_t fileSize = ftell(file);
        rewind(file);
        char* source = (char*)malloc(fileSize + 1);
        if (!source) {
            fclose(file);
            printf("Memory allocation failed for kernel source\n");
            return false;
        }
        // Fixed: use the number of bytes actually read (text-mode reads on
        // Windows can return fewer bytes than ftell reported).
        size_t readSize = fread(source, sizeof(char), fileSize, file);
        source[readSize] = '\0';
        fclose(file);
        // test program
        const char *code = source;
        cl_program program = clCreateProgramWithSource(context, 1, &code, &readSize, &err);
        if (!program || err != CL_SUCCESS) {
            printf("Failed to create program from source\n");
            free(source); // fixed: source leaked on this path
            return false;
        }
        for (size_t i = 0; i < options.size(); ++i) {
            err = clBuildProgram(program, 1, &device, options[i].c_str(), NULL, NULL);
            if (err != CL_SUCCESS) {
                size_t logSize = 0;
                clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize);
                char *buildLog = (char*)malloc(logSize);
                if (buildLog) {
                    clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, logSize, buildLog, NULL);
                    printf("Program build log: ");
                    for (size_t c = 0; c < logSize; c++) {
                        printf("%c", buildLog[c]);
                    }
                    free(buildLog);
                }
                clReleaseProgram(program);
                free(source); // fixed: source leaked on build failure
                return false;
            }
        }
        clReleaseProgram(program);
        free(source);
        return true;
    }
    bool mIsSupportAvailable = false;
    // Releases the OpenCL objects that were actually created, then closes the
    // runtime library handle.
    ~OpenCLProgramTest(){
        if (mIsSupportAvailable) {
            // Guard each release: construction may have bailed out early.
            if (device != nullptr && clReleaseDevice != nullptr) {
                clReleaseDevice(device);
            }
            if (context != nullptr && clReleaseContext != nullptr) {
                clReleaseContext(context);
            }
        }
        if (handle_ != nullptr) {
#if defined(_WIN32)
            FreeLibrary(static_cast<HMODULE>(handle_));
#else
            dlclose(handle_);
#endif
        }
    }
private:
    void *handle_ = nullptr;
    // Function pointers resolved from the runtime library (null until loaded).
    clGetPlatformIDsFunc clGetPlatformIDs = nullptr;
    clBuildProgramFunc clBuildProgram = nullptr;
    clCreateProgramWithSourceFunc clCreateProgramWithSource = nullptr;
    clGetProgramBuildInfoFunc clGetProgramBuildInfo = nullptr;
    clCreateContextFunc clCreateContext = nullptr;
    clGetDeviceIDsFunc clGetDeviceIDs = nullptr;
    clGetDeviceInfoFunc clGetDeviceInfo = nullptr;
    clReleaseProgramFunc clReleaseProgram = nullptr;
    clReleaseContextFunc clReleaseContext = nullptr;
    clReleaseDeviceFunc clReleaseDevice = nullptr;
    cl_platform_id platform = nullptr;
    cl_device_id device = nullptr;
    cl_context context = nullptr;
};
// Reads one build-option string per line from option.txt and compiles
// kernel.cl with each of them. argv[1] is the kernel's display name (the
// kernel body itself is always read from kernel.cl).
// Exit code: 0 on success, 1 when any option fails to build,
// -1 when no OpenCL runtime is available.
int main(int argc, char *argv[]) {
    std::string filename;
    if (argc > 1) {
        filename = argv[1];
    }
    std::vector<std::string> options;
    std::fstream file("option.txt");
    if (file.is_open()) {
        std::string line;
        while (getline(file, line)) { // one option string per line
            options.push_back(line);
        }
        file.close();
    }
    printf("test filename is %s\n", filename.c_str());
    OpenCLProgramTest BuildTest;
    if (BuildTest.mIsSupportAvailable) {
        if (BuildTest.TestProgram(options)) {
            return 0;
        }
        // Fixed: a build failure previously fell through to `return 0`,
        // reporting success to the caller.
        return 1;
    } else {
        printf("OpenCL init fail\n");
        return -1;
    }
}

View File

@ -0,0 +1,205 @@
import sys
import os
import re
import itertools
def run_cmd(args):
    """Run a command, merging stderr into stdout, and return the decoded output."""
    from subprocess import Popen, PIPE, STDOUT
    proc = Popen(args, stdout=PIPE, stderr=STDOUT)
    output, _ = proc.communicate()
    return output.decode('utf-8')
def extract_macros(file_content):
    """Scan OpenCL kernel source for preprocessor macros that need coverage.

    Returns [macros_num, macros]:
      macros_num: macro name -> candidate numeric values ({1,2,3,4,16} for
                  *LOCAL_SIZE* toggles, {1,2,3,4,8} for #if/#elif operands)
      macros:     plain on/off macros (value None) seen via #ifdef/#ifndef
                  or a `defined NAME` expression
    Macros the file itself #defines, plus MNN_SUPPORT_FP16 (always supplied
    by the runtime), are removed from the result.
    """
    numeric_macros = {}
    toggle_macros = {}

    def record_conditional(name):
        # *LOCAL_SIZE* macros are swept over several concrete sizes;
        # everything else is a plain on/off toggle.
        if "LOCAL_SIZE" in name:
            numeric_macros[name] = {1, 2, 3, 4, 16}
        else:
            toggle_macros[name] = None

    for _, name in re.findall(r'#(ifdef)\s+(\w+)', file_content):
        record_conditional(name)
    for _, name in re.findall(r'#(ifndef)\s+(\w+)', file_content):
        record_conditional(name)
    for _, name in re.findall(r'#(if)\s+(\w+)', file_content):
        if name != "defined":
            numeric_macros[name] = {1, 2, 3, 4, 8}
    for _, name in re.findall(r'#(elif)\s+(\w+)', file_content):
        if name != "defined":
            numeric_macros[name] = {1, 2, 3, 4, 8}
    for _, name in re.findall(r'(defined)\s+(\w+)', file_content):
        toggle_macros[name] = None
    # Anything the kernel defines itself never needs an external -D option.
    for _, name in re.findall(r'#(define)\s+(\w+)', file_content):
        toggle_macros.pop(name, None)
        numeric_macros.pop(name, None)
    toggle_macros.pop("MNN_SUPPORT_FP16", None)
    return [numeric_macros, toggle_macros]
def compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android):
    """Generate macro combinations for one kernel file and build-test them all.

    macros_all: [macros_num, macros] as produced by extract_macros.
    operator_macro: set of OPERATOR expansions to try (may be empty).
    extra_macro: per-file extra -D options some kernels need to compile.
    filename: kernel file name (the kernel body itself is read from kernel.cl).
    test_for_android: 1 to push artifacts and run the tester on-device via adb.

    Writes one option string per line into option.txt, then invokes the
    OpenCLProgramBuildTest binary (locally or on a device) which compiles
    kernel.cl with each option line.
    """
    macros_num = macros_all[0]
    macros = macros_all[1]
    # Baseline type/conversion defines shared by all kernels (FP32 build).
    float_option = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT=convert_float -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT3=convert_float3 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT=convert_float -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT3=convert_float3 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16"
    float_option += " -DINPUT_TYPE_I=float -DINPUT_TYPE_I4=float4 -DINPUT_TYPE=float -DINPUT_TYPE4=float4 -DINPUT_TYPE16=float16 -DRI_DATA=read_imagef -DOUTPUT_TYPE_I=float -DOUTPUT_TYPE_I4=float4 -DCONVERT_OUTPUT_I4=convert_float4 -DOUTPUT_TYPE=float -DOUTPUT_TYPE4=float4 -DOUTPUT_TYPE16=float16 -DCONVERT_OUTPUT4=convert_float4 -DCONVERT_OUTPUT16=convert_float16 -DWI_DATA=write_imagef"
    if filename in extra_macro:
        float_option += extra_macro[filename]
    keys = list(macros.keys())
    # Every on/off combination of the plain toggle macros.
    combinations = list(itertools.product([0, 1], repeat=len(keys)))
    options_normal = []
    for combination in combinations:
        option_normal = float_option
        for macro_name, macro_value in zip(keys, combination):
            if macro_value == 1:
                option_normal += f" -D{macro_name}={macro_value} "
        options_normal.append(option_normal)
    options_num_normal = []
    # Cross the toggle combinations with the numeric-valued macros.
    if len(macros_num) > 0:
        for i in {1, 2, 3, 4, 8}:
            # Fixed: the option string used to be overwritten inside the inner
            # loop, so only the last macro in macros_num was ever defined.
            # Now every numeric macro is defined (all with the same value i).
            # NOTE(review): extract_macros proposes 16 as a candidate for
            # LOCAL_SIZE macros, but this sweep uses {1,2,3,4,8} -- confirm.
            option_num = "".join(f" -D{macro_name}={i} " for macro_name in macros_num)
            for option_normal in options_normal:
                options_num_normal.append(option_normal + option_num)
    else:
        options_num_normal = options_normal
    options = []
    # OPERATOR macros: the first one is crossed with every combination; the
    # remaining operators each get a single representative combination.
    if len(operator_macro) > 0:
        has_combine = False
        for op in operator_macro:
            option_operator = f" -DOPERATOR={op} "
            if has_combine is True:
                options.append(options_num_normal[0] + option_operator)
            else:
                for option_num_normal in options_num_normal:
                    options.append(option_num_normal + option_operator)
                has_combine = True
    else:
        options = options_num_normal
    with open('option.txt', 'w') as outfile:
        for option in options:
            outfile.write(option + '\n')
    if test_for_android == 1:
        run_cmd(['adb', 'push', 'kernel.cl', '/data/local/tmp/MNN'])
        run_cmd(['adb', 'push', 'option.txt', '/data/local/tmp/MNN'])
        run_cmd(['adb', 'push', 'OpenCLProgramBuildTest.out', '/data/local/tmp/MNN'])
        res = run_cmd(['adb', 'shell', 'cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH && ./OpenCLProgramBuildTest.out %s'%(filename)])
        print(res)
    else:
        # Fixed: the kernel name was previously passed as a literal
        # "(unknown)" placeholder; forward the file actually being tested.
        if sys.platform.startswith('win'):
            res = run_cmd(['OpenCLProgramBuildTest.exe', filename])
        else:
            res = run_cmd(['./OpenCLProgramBuildTest.out', filename])
        print(res)
def main():
    # CLI: opencl_kernel_check.py [path] [without_subgroup] [test_for_android]
    #   path             - directory scanned for .cl kernels (default ".")
    #   without_subgroup - 1 skips kernels that require subgroup extensions
    #   test_for_android - 1 runs the build test on a device via adb
    print("opencl_kernel_check.py path without_subgroup test_for_android")
    path = '.'
    without_subgroup = 1
    test_for_android = 0
    if len(sys.argv) > 1:
        path = sys.argv[1]
    if len(sys.argv) > 2:
        without_subgroup = int(sys.argv[2])
    if len(sys.argv) > 3:
        test_for_android = int(sys.argv[3])
    # Candidate OPERATOR expansions for vectorized (float4) binary kernels.
    binaryvec_operator = {"in0+in1", "in0*in1", "in0-in1", "in0>in1?in0:in1", "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))",
    "in0>in1?in0:in1", "convert_float4(-isgreater(in0,in1))", "convert_float4(-isless(in0,in1))", "convert_float4(-islessequal(in0,in1))", "convert_float4(-isgreaterequal(in0,in1))", "convert_float4(-isequal(in0,in1))",
    "floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1",
    "pow(in0,in1)", "(in0-in1)*(in0-in1)", "(in1==(float)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float)PI)))", "convert_float4(-isnotequal(in0,in1))",
    "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1"}
    # OPERATOR expansions for scalar binary kernels.
    binary_operator = {"in0*in1", "in0+in1", "in0-in1", "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))", "in0>in1?in1:in0", "in0>in1?in0:in1", "(float)(isgreater(in0,in1))",
    "(float)(isless(in0,in1))", "(float)(islessequal(in0,in1))", "(float)(isgreaterequal(in0,in1))", "(float)(isequal(in0,in1))", "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))",
    "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", "pow(in0,in1)", "(in0-in1)*(in0-in1)",
    "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))", "(float)(isnotequal(in0,in1))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1"}
    # OPERATOR expansions for unary kernels.
    unary_operator = {"fabs(convert_float4(in))", "in*in", "rsqrt(convert_float4(in)>(float4)(0.000001)?convert_float4(in):(float4)(0.000001))", "-(in)", "exp(convert_float4(in))", "cos(convert_float4(in))", "sin(convert_float4(in))",
    "tan(convert_float4(in))", "atan(convert_float4(in))", "sqrt(convert_float4(in))", "ceil(convert_float4(in))", "native_recip(convert_float4(in))", "log1p(convert_float4(in))", "native_log(convert_float4(in)>(float4)(0.0000001)?convert_float4(in):(float4)(0.0000001))",
    "floor(convert_float4(in))", "in>(float4)((float)0)?(in+native_log(exp(convert_float4(-(in)))+(float4)(1.0))):(native_log(exp(convert_float4(in))+(float4)(1.0)))", "acosh(convert_float4(in))", "sinh(convert_float4(in))", "asinh(convert_float4(in))",
    "atanh(convert_float4(in))", "sign(convert_float4(in))", "round(convert_float4(in))", "cosh(convert_float4(in))", "erf(convert_float4(in))", "erfc(convert_float4(in))", "expm1(convert_float4(in))", "native_recip((float4)1+native_exp(convert_float4(-in)))",
    "(convert_float4(in)*native_recip((float4)1+native_exp(convert_float4(-in))))", "tanh(convert_float4(in))", "convert_float4(in)>(float4)(-3.0f)?(convert_float4(in)<(float4)(3.0f)?((convert_float4(in)*(convert_float4(in)+(float4)3.0f))/(float4)6.0f):convert_float4(in)):(float4)(0.0f)",
    "gelu(convert_float4(in))", "(erf(convert_float4(in)*(float4)0.7071067932881648)+(float4)1.0)*convert_float4(in)*(float4)0.5", "native_recip((float4)(1.0)+native_exp(convert_float4(-(in))))",
    "tanh(convert_float4(in))"}
    # Per-file extra -D defines these kernels cannot compile without.
    extra_macro = {}
    extra_macro["binary_subgroup_buf.cl"] = " -DINTEL_DATA=uint -DAS_INPUT_DATA=as_float -DAS_INPUT_DATA4=as_float4 -DAS_OUTPUT_DATA4=as_uint4 -DINTEL_SUB_GROUP_READ=intel_sub_group_block_read -DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4 -DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4"
    extra_macro["conv_2d_c1_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DINPUT_BLOCK_SIZE=16 -DINPUT_CHANNEL=16 -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1"
    extra_macro["conv_2d_c16_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DINPUT_BLOCK_SIZE=16 -DINPUT_CHANNEL=16 -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1"
    extra_macro["depthwise_conv2d_subgroup_buf.cl"] = " -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1"
    extra_macro["matmul_local_buf.cl"] = " -DOPWM=64 -DOPWN=128 -DCPWK=8 -DOPTM=4 -DOPTN=8"
    extra_macro["pooling_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DSTRIDE_Y=2 -DSTRIDE_X=2 -DKERNEL_Y=4 -DKERNEL_X=4"
    extra_macro["reduction_buf.cl"] = " -DOPERATE(a,b)=(a+b) -DVALUE=0"
    extra_macro["reduction.cl"] = " -DOPERATE(a,b)=(a+b) -DVALUE=0"
    extra_macro["unary_subgroup_buf.cl"] = " -DINTEL_DATA=uint -DAS_INPUT_DATA=as_float -DAS_INPUT_DATA4=as_float4 -DAS_OUTPUT_DATA4=as_uint4 -DINTEL_SUB_GROUP_READ=intel_sub_group_block_read -DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4 -DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4"
    # Walk every .cl kernel file in the target directory.
    for filename in os.listdir(path):
        if filename.endswith('.cl'):
            source_file = os.path.join(path, filename)
            with open(source_file, 'r') as file:
                file_content = file.read()
            # Stage the kernel under the fixed name the tester binary reads.
            with open('kernel.cl', 'w') as outfile:
                outfile.write(file_content)
            # Extract the macros that drive the combination sweep.
            macros_all = extract_macros(file_content)
            # Pick the OPERATOR set that matches this kernel family.
            operator_macro = {}
            if filename == "binary_buf.cl" or filename == "binary.cl" or filename == "loop.cl" or filename == "binary_subgroup_buf.cl":
                operator_macro = binaryvec_operator
            elif filename == "loop_buf.cl":
                operator_macro = binary_operator
            elif filename == "unary_buf.cl" or filename == "unary.cl" or filename == "unary_subgroup_buf.cl":
                operator_macro = unary_operator
            if "subgroup" in filename and without_subgroup == 1:
                continue
            compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android)
if __name__ == "__main__":
    main()

View File

@ -1,6 +1,7 @@
option(LLM_SUPPORT_VISION "Llm model support vision input." OFF)
option(LLM_SUPPORT_AUDIO "Llm model support audio input." OFF)
option(BUILD_MLS "Build PC Commandline." OFF)
option(LLM_USE_MINJA "Use minja to apply template" ON)
set(LLM_DEPS ${MNN_DEPS})
if (LLM_SUPPORT_VISION AND MNN_BUILD_OPENCV)
@ -22,7 +23,7 @@ endif()
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)
# source files
FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp ${CMAKE_CURRENT_LIST_DIR}/src/speculative_decoding/*.cpp)
FILE(GLOB_RECURSE SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*)
if (MNN_SEP_BUILD)
if (MNN_BUILD_SHARED_LIBS)
@ -37,12 +38,17 @@ if (MNN_SEP_BUILD)
else()
add_library(llm OBJECT ${SRCS})
endif()
if (LLM_USE_MINJA)
target_compile_options(llm PRIVATE -DLLM_USE_MINJA)
add_executable(apply_template ${CMAKE_CURRENT_LIST_DIR}/demo/apply_template.cpp)
target_link_libraries(apply_template ${LLM_DEPS})
endif()
if (LLM_SUPPORT_VISION AND MNN_BUILD_OPENCV)
target_compile_definitions(llm PRIVATE LLM_SUPPORT_VISION)
endif()
if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO)
target_compile_definitions(llm PRIVATE LLM_SUPPORT_AUDIO)
add_definitions(-DLLM_SUPPORT_AUDIO)
endif()
add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp)
@ -54,6 +60,7 @@ target_link_libraries(embedding_demo ${LLM_DEPS})
target_link_libraries(rollback_demo ${LLM_DEPS})
target_link_libraries(llm_bench ${LLM_DEPS})
if (BUILD_MLS)
set(CMAKE_OSX_DEPLOYMENT_TARGET "13.0" CACHE STRING "Minimum macOS version" FORCE)

View File

@ -166,41 +166,28 @@ static int serve(int argc, const char *argv[]) {
bool invalid_param{false};
std::string config_path{};
std::string arg{};
std::string model_name{};
if (argc < 3) {
print_usage();
return 1;
}
arg = argv[2];
if (arg.find('-') != 0) {
config_path = (fs::path(mls::FileUtils::GetBaseCacheDir()) / arg / "config.json").string();
}
for (int i = 2; i < argc; i++) {
arg = argv[i];
if (arg.find('-') != 0) {
model_name = arg;
config_path = (fs::path(mls::FileUtils::GetBaseCacheDir()) / arg / "config.json").string();
continue;
}
if (arg == "-c") {
if (++i >= argc) {
invalid_param = true;
break;
}
config_path = mls::FileUtils::ExpandTilde(argv[i]);
model_name = mls::FileUtils::GetFileName(fs::path(config_path).parent_path());
continue;
}
}
if (invalid_param) {
std::cerr << "Error: Missing value after -c option" << std::endl;
print_usage();
return 1;
}
mls::MlsServer server;
bool is_r1 = IsR1(config_path);
auto llm = create_and_prepare_llm(config_path.c_str(), !is_r1);
server.Start(model_name, llm.get(), is_r1);
server.Start(llm.get(), is_r1);
return 0;
}

View File

@ -21,96 +21,19 @@ std::string GetCurrentTimeAsString() {
return std::to_string(seconds);
}
// bool FromJson(const json& j, PromptItem& item) {
// if (!j.is_object()) {
// return false;
// }
// if (!j.contains("role") || !j["role"].is_string()) {
// return false;
// }
// if (!j.contains("content") || !j["content"].is_string()) {
// return false;
// }
// item.first = j["role"].get<std::string>(); // Role
// item.second = j["content"].get<std::string>(); // Content
// return true;
// }
// Decode a base64 string (standard RFC 4648 alphabet).
// Whitespace and '=' padding characters are skipped; any other character
// outside the alphabet aborts the decode and returns an empty string.
// Fix: the original also declared a 65-byte encode table (b64_table) that was
// never read — only the reverse lookup table is needed for decoding.
std::string base64_decode(const std::string &ascdata) {
    // Maps an ASCII code (0..127) to its 6-bit base64 value; 64 marks
    // "not part of the alphabet".
    static const char reverse_table[128] = {
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63,
        52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64,
        64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
        15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64,
        64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64
    };
    std::string retval;
    const std::string::const_iterator last = ascdata.end();
    int bits_collected = 0;          // how many decoded bits are pending in accumulator
    unsigned int accumulator = 0;    // bit buffer; a full byte is emitted once >= 8 bits
    for (std::string::const_iterator i = ascdata.begin(); i != last; ++i) {
        const int c = *i;
        if (::std::isspace(c) || c == '=') {
            // Skip whitespace and padding. Be liberal in what you accept.
            continue;
        }
        // Reject non-ASCII bytes (c may be negative for signed char) and
        // ASCII characters that are not in the base64 alphabet.
        if ((c > 127) || (c < 0) || (reverse_table[c] > 63)) {
            MNN_ERROR("Base64 decode auth code failed.\n");
            return "";
        }
        accumulator = (accumulator << 6) | reverse_table[c];
        bits_collected += 6;
        if (bits_collected >= 8) {
            // Emit the most significant complete byte from the accumulator.
            bits_collected -= 8;
            retval += static_cast<char>((accumulator >> bits_collected) & 0xffu);
        }
    }
    return retval;
}
// Decode a "data:...;base64,<payload>" URL and persist the decoded bytes to a
// timestamped .jpg file in the working directory. Returns the file name that
// was written.
static std::string SaveImageFromDataUrl(const std::string& data_url) {
    // Everything after the first ',' is the base64 payload; when no comma is
    // present, treat the whole string as the payload.
    const auto sep = data_url.find(',');
    const std::string payload =
        (sep == std::string::npos) ? data_url : data_url.substr(sep + 1);
    const std::string decoded = base64_decode(payload);
    const std::string filename = "image_" + GetCurrentTimeAsString() + ".jpg";
    std::ofstream out(filename, std::ios::binary);
    out.write(decoded.data(), decoded.size());
    out.close();
    return filename;
}
bool FromJson(const json& j, PromptItem& item) {
if (!j.is_object() || !j.contains("role")) return false;
item.first = j["role"].get<std::string>();
if (!j.contains("content")) return false;
// Handle text or array content
if (j["content"].is_string()) {
item.second = j["content"].get<std::string>();
} else if (j["content"].is_array()) {
std::string combined;
for (const auto& elem : j["content"]) {
if (elem.is_object() && elem.value("type", "") == "image_url"
&& elem.contains("image_url")
&& elem["image_url"].contains("url")) {
std::string data_url = elem["image_url"]["url"].get<std::string>();
std::string path = SaveImageFromDataUrl(data_url);
combined += "<img>" + path + "</img>";
}
// other types can be handled here
}
item.second = combined;
} else {
return false;
if (!j.is_object()) {
return false;
}
if (!j.contains("role") || !j["role"].is_string()) {
return false;
}
if (!j.contains("content") || !j["content"].is_string()) {
return false;
}
item.first = j["role"].get<std::string>(); // Role
item.second = j["content"].get<std::string>(); // Content
return true;
}
@ -229,12 +152,7 @@ void AllowCors(httplib::Response& res) {
res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
}
std::vector<std::string> MlsServer::GetLocalModels() {
return {current_model_name_};
}
void MlsServer::Start(const std::string& model_name, MNN::Transformer::Llm* llm, bool is_r1) {
this->current_model_name_ = model_name;
void MlsServer::Start(MNN::Transformer::Llm* llm, bool is_r1) {
this->is_r1_ = is_r1;
// Create a server instance
httplib::Server server;
@ -244,146 +162,100 @@ void MlsServer::Start(const std::string& model_name, MNN::Transformer::Llm* llm,
AllowCors(res);
res.set_content(html_content, "text/html");
});
// Add OpenAI compatible /v1 routes for all APIs
auto handleChatCompletions = [&](const httplib::Request &req, httplib::Response &res, bool is_v1) {
if (!json::accept(req.body)) {
json err;
err["error"] = "Invalid JSON in request body.";
res.status = 400;
res.set_content(err.dump(), "application/json");
return;
}
json request_json = json::parse(req.body, nullptr, false);
json messages = request_json["messages"];
std::cout<<"received messages:"<<messages.dump(0)<<std::endl;
std::string model = request_json.value("model", "undefined-model");
bool stream = request_json.value("stream", false);
if (!stream) {
Answer(llm, messages, [&res, model](const std::string& answer) {
json response_json = {
{"id", "chatcmpl" + GetCurrentTimeAsString()},
{"object", "chat.completion"},
{"created", static_cast<int>(time(nullptr))},
{"model", model},
{
"choices", json::array({
{
{"index", 0},
{
"message", {
{"role", "assistant"},
{"content", answer}
}
},
{"finish_reason", "stop"}
}
})
},
{
"usage", {
{"prompt_tokens", 10},
{"completion_tokens", 7},
{"total_tokens", 17}
}
}
};
res.set_content(response_json.dump(), "application/json");
});
return;
}
res.set_header("Content-Type", "text/event-stream");
res.set_header("Cache-Control", "no-cache");
res.set_header("Connection", "keep-alive");
res.set_chunked_content_provider(
"text/event-stream",
[llm, messages, model, this](size_t /*offset*/, httplib::DataSink &sink) {
auto sse_callback = [&, this](const std::string &partial_text, bool end) {
std::string finish_reason = end ? "stop" : "";
json sse_json = {
{"id", "chatcmpl-" + GetCurrentTimeAsString()},
{"object", "chat.completion.chunk"},
{"created", static_cast<int>(std::time(nullptr))},
{"model", model},
{"choices", json::array({
{
{"delta", {{"content", partial_text}}},
{"index", 0},
{"finish_reason", finish_reason}
}
})}
};
std::string chunk_str = "data: " + sse_json.dump() + "\n\n";
sink.os.write(chunk_str.c_str(), chunk_str.size());
sink.os.flush();
};
AnswerStreaming(llm, messages, sse_callback);
std::string done_str = "data: [DONE]\n\n";
sink.os.write(done_str.c_str(), done_str.size());
sink.os.flush();
sink.done();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
return false;
}
);
};
// Route registrations
server.Get("/v1/", [this](const httplib::Request& req, httplib::Response& res) {
AllowCors(res);
res.set_content(html_content, "text/html");
});
server.Post("/reset", [&](const httplib::Request &req, httplib::Response &res) {
printf("POST /reset\n");
llm->reset();
res.set_content("{\"status\": \"ok\"}", "application/json");
printf("POST /reset\n");
AllowCors(res);
llm->reset();
res.set_content("{\"status\": \"ok\"}", "application/json");
});
server.Post("/v1/reset", [&](const httplib::Request &req, httplib::Response &res) {
printf("POST /v1/reset\n");
llm->reset();
res.set_content("{\"status\": \"ok\"}", "application/json");
});
server.Options("/chat/completions", [](const httplib::Request& /*req*/, httplib::Response& res) {
AllowCors(res);
res.status = 200;
});
server.Options("/v1/chat/completions", [](const httplib::Request& /*req*/, httplib::Response& res) {
AllowCors(res);
res.status = 200;
});
server.Get("/models/list", [this](const httplib::Request& req, httplib::Response& res) {
AllowCors(res);
std::vector<std::string> model_names = GetLocalModels();
json response = json::array();
for (const auto& name : model_names) {
response.push_back(name);
}
res.set_content(response.dump(), "application/json");
});
server.Get("/v1/models", [this](const httplib::Request& req, httplib::Response& res) {
AllowCors(res);
std::vector<std::string> model_names = GetLocalModels();
json response = json::array();
for (const auto& name : model_names) {
response.push_back(name);
}
res.set_content(response.dump(), "application/json");
});
server.Post("/chat/completions", [&](const httplib::Request &req, httplib::Response &res) {
std::cout << "POST /chat/completions, handled by thread: "
<< std::this_thread::get_id() << std::endl;
handleChatCompletions(req, res, false);
AllowCors(res);
if (!json::accept(req.body)) {
json err;
err["error"] = "Invalid JSON in request body.";
res.status = 400;
res.set_content(err.dump(), "application/json");
return;
}
json request_json = json::parse(req.body, nullptr, false);
json messages = request_json["messages"];
std::cout<<"received messages:"<<messages.dump(0)<<std::endl;
std::string model = request_json.value("model", "undefined-model");
bool stream = request_json.value("stream", false);
if (!stream) {
Answer(llm, messages, [&res, model](const std::string& answer) {
json response_json = {
{"id", "chatcmpl" + GetCurrentTimeAsString()},
{"object", "chat.completion"},
{"created", static_cast<int>(time(nullptr))},
{"model", model},
{
"choices", json::array({
{
{"index", 0},
{
"message", {
{"role", "assistant"},
{"content", answer}
}
},
{"finish_reason", "stop"}
}
})
},
{
"usage", {
{"prompt_tokens", 10},
{"completion_tokens", 7},
{"total_tokens", 17}
}
}
};
res.set_content(response_json.dump(), "application/json");
});
return;
}
res.set_header("Content-Type", "text/event-stream");
res.set_header("Cache-Control", "no-cache");
res.set_header("Connection", "keep-alive");
res.set_chunked_content_provider(
"text/event-stream",
[llm, messages, model, this](size_t /*offset*/, httplib::DataSink &sink) {
auto sse_callback = [&, this](const std::string &partial_text, bool end) {
std::string finish_reason = end ? "stop" : "";
json sse_json = {
{"id", "chatcmpl-" + GetCurrentTimeAsString()},
{"object", "chat.completion.chunk"},
{"created", static_cast<int>(std::time(nullptr))},
{"model", model},
{"choices", json::array({
{
{"delta", {{"content", partial_text}}},
{"index", 0},
{"finish_reason", finish_reason}
}
})}
};
std::string chunk_str = "data: " + sse_json.dump() + "\n\n";
sink.os.write(chunk_str.c_str(), chunk_str.size());
sink.os.flush();
};
AnswerStreaming(llm, messages, sse_callback);
std::string done_str = "data: [DONE]\n\n";
sink.os.write(done_str.c_str(), done_str.size());
sink.os.flush();
sink.done();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
return false;
}
);
});
server.Post("/v1/chat/completions", [&](const httplib::Request &req, httplib::Response &res) {
std::cout << "POST /v1/chat/completions, handled by thread: "
<< std::this_thread::get_id() << std::endl;
handleChatCompletions(req, res, true);
});
// Start the server on port
std::cout << "Starting server on http://localhost:9090\n";
if (!server.listen("0.0.0.0", 9090)) {

View File

@ -258,11 +258,9 @@ class MlsServer {
</body>
</html>
)""";
void Start(const std::string& modle_name, MNN::Transformer::Llm* llm, bool is_r1);
void Start(MNN::Transformer::Llm* llm, bool is_r1);
bool is_r1_{false};
private:
std::string current_model_name_{};
std::vector<std::string> GetLocalModels();
void Answer(MNN::Transformer::Llm* llm, const json &messages, std::function<void(const std::string&)> on_result);
void AnswerStreaming(MNN::Transformer::Llm* llm,
const json& messages,

View File

@ -0,0 +1,193 @@
#include <MNN/MNNDefine.h>
#include "../src/minja/chat_template.hpp"
#include <rapidjson/document.h>
#include <rapidjson/prettywriter.h>
#include <fstream>
#include <iostream>
#include <sstream>
static int test(const char* testjson) {
rapidjson::Document document;
std::ifstream inputFs(testjson);
std::ostringstream osString;
if (inputFs.fail()) {
MNN_ERROR("Open %s error\n", testjson);
return 0;
}
osString << inputFs.rdbuf();
document.Parse(osString.str().c_str());
if (document.HasParseError() || (!document.IsArray())) {
MNN_ERROR("Invalid json\n");
return 0;
}
int pos = 0;
for (auto& iter : document.GetArray()) {
std::string res = iter["res"].GetString();
std::string chatTemplate = iter["chat_template"].GetString();
std::string bos;
std::string eos;
if (iter.HasMember("bos")) {
bos = iter["bos"].GetString();
}
if (iter.HasMember("eos")) {
eos = iter["eos"].GetString();
}
minja::chat_template tmpl(chatTemplate, bos, eos);
minja::chat_template_inputs inputs;
inputs.messages.CopyFrom(iter["messages"], inputs.messages.GetAllocator());
if (iter.HasMember("extras")) {
inputs.extra_context.CopyFrom(iter["extras"], inputs.extra_context.GetAllocator());
}
inputs.add_generation_prompt = true;
auto newres = tmpl.apply(inputs);
if (res != newres) {
MNN_ERROR("Error for %d template\n", pos);
MNN_ERROR("Origin:\n%s\n", res.c_str());
MNN_ERROR("Compute:\n%s\n", newres.c_str());
return 0;
}
pos++;
}
MNN_PRINT("Test %d template, All Right\n", pos);
return 0;
}
// apply_template entry point. Two modes:
//   ./apply_template token_config.json [token_config2.json ...]
//       Render a fixed 4-message conversation with each tokenizer config's
//       chat_template and append {messages, extras, chat_template, bos, eos,
//       res} records to result.json (creating it if absent).
//   ./apply_template test.json 1
//       Replay the recorded cases in test.json via test() and verify them.
// Fix: the config loop read argv[1] instead of argv[i], so only the first
// config file could ever be processed; now each argument is handled.
int main(int argc, const char* argv[]) {
    if (argc < 2) {
        MNN_ERROR("Usage: ./apply_template token_config.json \n");
        MNN_ERROR("Or \n");
        MNN_ERROR("Usage: ./apply_template test.json 1\n");
        return 0;
    }
    if (argc >= 3) {
        // Any extra argument switches to verification mode.
        MNN_PRINT("Test %s\n", argv[1]);
        test(argv[1]);
        return 0;
    }
    rapidjson::Document resDocument;
    {
        // Load the previous result.json so new records are appended; fall
        // back to a fresh array when the file is missing or unparsable.
        std::ifstream inputFs("result.json");
        bool valid = false;
        if (!inputFs.fail()) {
            std::ostringstream osString;
            osString << inputFs.rdbuf();
            resDocument.Parse(osString.str().c_str());
            if (resDocument.HasParseError()) {
                MNN_ERROR("Invalid json\n");
            } else {
                valid = true;
                MNN_PRINT("Has result.json, append it\n");
            }
        }
        if (!valid) {
            resDocument.SetArray();
            MNN_PRINT("Create new result.json\n");
        }
    }
    for (int i=1; i<argc; ++i) {
        // Was argv[1]: ignored the loop index and reprocessed the first file.
        auto tokenConfigPath = argv[i];
        FUNC_PRINT_ALL(tokenConfigPath, s);
        rapidjson::Document document;
        std::ifstream inputFs(tokenConfigPath);
        std::ostringstream osString;
        if (inputFs.fail()) {
            MNN_ERROR("Open File error\n");
            return 0;
        }
        osString << inputFs.rdbuf();
        document.Parse(osString.str().c_str());
        if (document.HasParseError()) {
            MNN_ERROR("Invalid json\n");
            return 0;
        }
        std::string bosToken, eosToken;
        // bos_token / eos_token may be either a plain string or an object
        // with a "content" field (HuggingFace tokenizer_config style).
        auto loadtoken = [](const rapidjson::GenericValue<rapidjson::UTF8<>>& value, std::string& dst) {
            if (value.IsString()) {
                dst = value.GetString();
                return;
            }
            if (value.IsObject()) {
                if (value.HasMember("content") && value["content"].IsString()) {
                    dst = value["content"].GetString();
                    return;
                }
            }
        };
        if (document.HasMember("bos_token")) {
            loadtoken(document["bos_token"], bosToken);
        }
        if (document.HasMember("eos_token")) {
            loadtoken(document["eos_token"], eosToken);
        }
        std::string templateChat;
        if (document.HasMember("chat_template")) {
            templateChat = document["chat_template"].GetString();
        }
        if (templateChat.empty()) {
            MNN_ERROR("Invalid json with no chat_template\n");
            return 0;
        }
        minja::chat_template tmpl(templateChat, bosToken, eosToken);
        minja::chat_template_inputs inputs;
        inputs.extra_context.SetObject();
        inputs.extra_context.GetObject().AddMember("enable_thinking", false, inputs.extra_context.GetAllocator());
        // Fixed sample conversation used to exercise the template.
        inputs.messages.Parse(R"([
    {
        "role": "system",
        "content": "You are a helpful assistant."
    },
    {
        "role": "user",
        "content": "What is 8 * 12."
    },
    {
        "role": "assistant",
        "content": "96."
    },
    {
        "role": "user",
        "content": "What is 9 * 8?"
    }
])");
        inputs.add_generation_prompt = true;
        auto res = tmpl.apply(inputs);
        MNN_PRINT("%s", res.c_str());
        // Record this case so it can later be replayed by test().
        rapidjson::Value v;
        v.SetObject();
        rapidjson::Value messages;
        messages.CopyFrom(inputs.messages, resDocument.GetAllocator());
        rapidjson::Value extras;
        extras.CopyFrom(inputs.extra_context, resDocument.GetAllocator());
        v.AddMember("messages", messages, resDocument.GetAllocator());
        v.AddMember("extras", extras, resDocument.GetAllocator());
        {
            // SetString with an allocator copies the buffer, so the local
            // std::string may safely go out of scope.
            rapidjson::Value tv;
            tv.SetString(templateChat.c_str(), resDocument.GetAllocator());
            v.AddMember("chat_template", tv, resDocument.GetAllocator());
        }
        if (!bosToken.empty()) {
            rapidjson::Value tv;
            tv.SetString(bosToken.c_str(), resDocument.GetAllocator());
            v.AddMember("bos", tv, resDocument.GetAllocator());
        }
        if (!eosToken.empty()) {
            rapidjson::Value tv;
            tv.SetString(eosToken.c_str(), resDocument.GetAllocator());
            v.AddMember("eos", tv, resDocument.GetAllocator());
        }
        {
            rapidjson::Value tv;
            tv.SetString(res.c_str(), resDocument.GetAllocator());
            v.AddMember("res", tv, resDocument.GetAllocator());
        }
        resDocument.GetArray().PushBack(v, resDocument.GetAllocator());
    }
    // Pretty-print the accumulated records back to result.json.
    rapidjson::StringBuffer buf;
    rapidjson::PrettyWriter<rapidjson::StringBuffer> bufwriter(buf);
    resDocument.Accept(bufwriter);
    std::ofstream os("result.json");
    os << buf.GetString();
    return 0;
}

View File

@ -13,6 +13,9 @@
#include <sstream>
#include <stdlib.h>
#include <initializer_list>
#ifdef LLM_SUPPORT_AUDIO
#include "audio/audio.hpp"
#endif
using namespace MNN::Transformer;
static void tuning_prepare(Llm* llm) {
@ -77,6 +80,19 @@ static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_
if (max_token_number > 0) {
llm->set_config("{\"max_new_tokens\":1}");
}
#ifdef LLM_SUPPORT_AUDIO
std::vector<float> waveform;
llm->setWavformCallback([&](const float* ptr, size_t size, bool last_chunk) {
waveform.reserve(waveform.size() + size);
waveform.insert(waveform.end(), ptr, ptr + size);
if (last_chunk) {
auto waveform_var = MNN::Express::_Const(waveform.data(), {(int)waveform.size()}, MNN::Express::NCHW, halide_type_of<float>());
MNN::AUDIO::save("output.wav", waveform_var, 24000);
waveform.clear();
}
return true;
});
#endif
for (int i = 0; i < prompts.size(); i++) {
const auto& prompt = prompts[i];
@ -96,7 +112,6 @@ static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_
} else {
llm->response(prompt);
}
llm->reset();
prompt_len += context->prompt_len;
decode_len += context->gen_seq_len;
vision_time += context->vision_us;
@ -105,6 +120,8 @@ static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_
decode_time += context->decode_us;
sample_time += context->sample_us;
}
llm->generateWavform();
float vision_s = vision_time / 1e6;
float audio_s = audio_time / 1e6;
float prefill_s = prefill_time / 1e6;
@ -170,12 +187,20 @@ static int eval(Llm* llm, std::string prompt_file, int max_token_number) {
std::ifstream prompt_fs(prompt_file);
std::vector<std::string> prompts;
std::string prompt;
//#define LLM_DEMO_ONELINE
#ifdef LLM_DEMO_ONELINE
std::ostringstream tempOs;
tempOs << prompt_fs.rdbuf();
prompt = tempOs.str();
prompts = {prompt};
#else
while (std::getline(prompt_fs, prompt)) {
if (prompt.back() == '\r') {
prompt.pop_back();
}
prompts.push_back(prompt);
}
#endif
prompt_fs.close();
if (prompts.empty()) {
return 1;
@ -240,6 +265,16 @@ int main(int argc, const char* argv[]) {
std::istringstream os(argv[3]);
os >> max_token_number;
}
if (argc >= 5) {
MNN_PRINT("Set not thinking, only valid for Qwen3\n");
llm->set_config(R"({
"jinja": {
"context": {
"enable_thinking":false
}
}
})");
}
std::string prompt_file = argv[2];
return eval(llm.get(), prompt_file, max_token_number);
}

View File

@ -81,7 +81,7 @@ public:
virtual Express::VARP embedding(const std::vector<int>& input_ids);
Express::VARP forward(const std::vector<int>& input_ids, bool is_prefill = true);
Express::VARP forward(MNN::Express::VARP input_embeds);
virtual Express::VARP forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos);
virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos);
virtual int sample(Express::VARP logits, int offset = 0, int size = 0);
void reset();
void tuning(TuneType type, std::vector<int> candidates);
@ -130,8 +130,9 @@ protected:
std::vector<std::shared_ptr<Express::Module>> mModules, mPrefillModules, mDecodeModules, mCurrentModules;
const Express::Module* mBaseModule = nullptr;
Express::VARP inputsEmbeds, attentionMask, positionIds;
std::vector<Express::VARP> mInputsEmbedsVarVec, mAttentionMaskVarVec, mPositionIdsVarVec;
std::vector<Express::VARP> mAttentionMaskVarVec, mPositionIdsVarVec;
Express::VARP logitsAllIdx, logitsLastIdx;
int mSeqLenIndex = 0;
private:
// decoding phase will use speculative decoding
void speculativeGenerate(int max_token);

View File

@ -10,7 +10,6 @@
#include <iostream>
#include <sstream>
#include <unordered_set>
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include "cpp/ExprDebug.hpp"
@ -261,7 +260,6 @@ void Llm::load() {
logitsLastIdx = _var<int>({-1}, {1});
logitsAllIdx = _var<int>({0}, {1});
// index match with seq_len
mInputsEmbedsVarVec.resize(decode_type_num);
mAttentionMaskVarVec.resize(decode_type_num);
mPositionIdsVarVec.resize(decode_type_num);
for(int i = 0; i < decode_type_num; i++) {
@ -281,7 +279,6 @@ void Llm::load() {
}
mPositionIdsVarVec[i] = _Input({index}, NCHW, halide_type_of<int>());
mInputsEmbedsVarVec[i] = _Input({index, 1, mConfig->hidden_size()}, NCHW);
}
}
@ -345,6 +342,8 @@ void Llm::tuning(TuneType type, std::vector<int> candidates) {
}
mCurrentModules = mDecodeModules;
int decode_seq = 1;
// Set to decode mode
mContext->gen_seq_len = 1;
if(mLookAhead) {
// start autoregressive decoding
std::vector<int> input_ids = {0};
@ -402,9 +401,10 @@ void Llm::setKVCacheInfo(size_t add, size_t remove, int* reserve, int n_reserve)
mMeta->add = add;
}
Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) {
std::vector<Express::VARP> Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) {
VARP logits;
Express::VARP logitsIndex;
bool inDecode = mContext->gen_seq_len > 0;
// TODO : to be improved
if (mConfig->all_logits() || (mLookAhead && mCurrentModules.size() > 1)) {
logitsIndex = logitsAllIdx;
@ -420,7 +420,7 @@ Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Exp
outputs = mCurrentModules.back()->onForward({hiddenState, mask, inputPos, logitsIndex});
}
if (outputs.empty()) {
return nullptr;
return outputs;
}
logits = outputs[0];
@ -477,9 +477,8 @@ Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Exp
}
}
#endif
mMeta->sync();
return logits;
return outputs;
}
VARP Llm::forward(const std::vector<int>& input_ids, bool is_prefill) {
@ -488,11 +487,15 @@ VARP Llm::forward(const std::vector<int>& input_ids, bool is_prefill) {
}
VARP Llm::forward(MNN::Express::VARP input_embeds) {
int seq_len = input_embeds->getInfo()->dim[0];
int seq_len = input_embeds->getInfo()->dim[mSeqLenIndex];
mMeta->add = seq_len;
auto attention_mask = gen_attention_mask(seq_len);
auto position_ids = gen_position_ids(seq_len);
auto logits = forwardRaw(input_embeds, attention_mask, position_ids);
auto out = forwardRaw(input_embeds, attention_mask, position_ids);
if (out.empty()) {
return nullptr;
}
auto logits = out[0];
mContext->all_seq_len += seq_len;
mContext->gen_seq_len++;
return logits;
@ -511,6 +514,7 @@ void Llm::reset() {
mContext->output_tokens.clear();
mContext->history_tokens.clear();
mContext->all_seq_len = 0;
mMeta->remove = mMeta->previous;
}
void Llm::generate_init(std::ostream* os, const char* end_with) {
@ -728,16 +732,7 @@ VARP Llm::embedding(const std::vector<int>& input_ids) {
AUTOTIME;
int hidden_size = mConfig->hidden_size();
int seq_len = static_cast<int>(input_ids.size());
if (mInputsEmbedsVarVec.size() > 0) {
if(seq_len == 1) {
mDiskEmbedding->embedding(input_ids, mInputsEmbedsVarVec[0]->writeMap<float>());
return mInputsEmbedsVarVec[0];
}
if(mInputsEmbedsVarVec.size() > 1 && seq_len == mDraftLength) {
mDiskEmbedding->embedding(input_ids, mInputsEmbedsVarVec[1]->writeMap<float>());
return mInputsEmbedsVarVec[1];
}
}
VARP res = _Input({seq_len, 1, hidden_size}, NCHW);
// disk embedding to save memory
mDiskEmbedding->embedding(input_ids, res->writeMap<float>());
@ -755,27 +750,24 @@ std::string Llm::tokenizer_decode(int id) {
}
VARP Llm::gen_attention_mask(int seq_len) {
bool useSquareMask = false;
int kv_seq_len = mContext->all_seq_len + seq_len;
if (mConfig->attention_mask() == "float") {
// currently only metal supoort square mask
useSquareMask = (mConfig->backend_type() == "metal");
if(useSquareMask) {
kv_seq_len = seq_len;
}
if(seq_len == 1) {
return mAttentionMaskVarVec[0];
}
if (useSquareMask && mAttentionMaskVarVec.size() > 1 && seq_len == mDraftLength) {
return mAttentionMaskVarVec[1];
// Use square mask
kv_seq_len = seq_len;
if (mAttentionMaskVarVec.size() > 0) {
if(seq_len == 1) {
return mAttentionMaskVarVec[0];
}
if (mAttentionMaskVarVec.size() > 1 && seq_len == mDraftLength) {
return mAttentionMaskVarVec[1];
}
}
attentionMask = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of<float>());
auto ptr = attentionMask->writeMap<float>();
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < kv_seq_len; j++) {
int row = useSquareMask ? i : i + mContext->all_seq_len;
ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits<float>::lowest();
ptr[kv_seq_len * i + j] = (j > i) * std::numeric_limits<float>::lowest();
}
}
return attentionMask;

View File

@ -0,0 +1,746 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#ifdef LLM_USE_MINJA
#include "minja.hpp"
#include "chat_template.hpp"
namespace minja {
// Helper to convert Value to string
static std::string valueToString(const rapidjson::Value& val) {
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
val.Accept(writer);
return buffer.GetString();
}
chat_template::chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
#define MINJA_ADD_TEST
#ifdef MINJA_ADD_TEST
auto contains = [](const std::string & haystack, const std::string & needle) {
return haystack.find(needle) != std::string::npos;
};
// This entire block needs to be refactored to use rapidjson.
// This is a significant change due to how objects and arrays are constructed.
// I will need a Document (with its allocator) for each JSON structure.
Document d_render_test; // Document for constructing test JSONs
auto& alloc = d_render_test.GetAllocator();
const std::string user_needle = "<User Needle>";
const std::string sys_needle = "<System Needle>";
// const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
rapidjson::Value dummy_str_user_msg(rapidjson::kObjectType);
dummy_str_user_msg.AddMember("role", "user", alloc);
dummy_str_user_msg.AddMember("content", rapidjson::StringRef(user_needle.c_str()), alloc);
// const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
rapidjson::Value dummy_typed_user_msg(rapidjson::kObjectType);
dummy_typed_user_msg.AddMember("role", "user", alloc);
{
rapidjson::Value content_array(rapidjson::kArrayType);
rapidjson::Value content_item(rapidjson::kObjectType);
content_item.AddMember("type", "text", alloc);
content_item.AddMember("text", rapidjson::StringRef(user_needle.c_str()), alloc);
content_array.PushBack(content_item, alloc);
dummy_typed_user_msg.AddMember("content", content_array, alloc);
}
rapidjson::Value dummy_str_user_msg_copy1;
dummy_str_user_msg_copy1.CopyFrom(dummy_str_user_msg, alloc);
rapidjson::Value messages_for_render1(rapidjson::kArrayType);
messages_for_render1.PushBack(dummy_str_user_msg_copy1, alloc);
rapidjson::Value no_tools(rapidjson::kArrayType); // Assuming empty array for no tools
// For capability detection, polyfills are off, so copy is fine.
rapidjson::Value messages_typed_content_test1(rapidjson::kArrayType);
dummy_str_user_msg_copy1.CopyFrom(dummy_str_user_msg, alloc);
messages_typed_content_test1.PushBack(dummy_str_user_msg_copy1, alloc);
rapidjson::Value no_tools_copy1; no_tools_copy1.CopyFrom(no_tools, alloc);
rapidjson::Value dummy_typed_user_msg_copy1;
dummy_typed_user_msg_copy1.CopyFrom(dummy_typed_user_msg, alloc);
rapidjson::Value messages_typed_content_test2(rapidjson::kArrayType);
messages_typed_content_test2.PushBack(dummy_typed_user_msg_copy1, alloc);
rapidjson::Value no_tools_copy2; no_tools_copy2.CopyFrom(no_tools, alloc);
caps_.requires_typed_content =
!contains(try_raw_render(messages_typed_content_test1, no_tools_copy1, false, alloc), user_needle) &&
contains(try_raw_render(messages_typed_content_test2, no_tools_copy2, false, alloc), user_needle);
// const auto dummy_user_msg = caps_.requires_typed_content ? dummy_typed_user_msg : dummy_str_user_msg;
rapidjson::Value dummy_user_msg(rapidjson::kObjectType);
if (caps_.requires_typed_content) {
dummy_user_msg.CopyFrom(dummy_typed_user_msg, alloc);
} else {
dummy_user_msg.CopyFrom(dummy_str_user_msg, alloc);
}
// const json needle_system_msg = {
// {"role", "system"},
// {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
// };
rapidjson::Value needle_system_msg(rapidjson::kObjectType);
needle_system_msg.AddMember("role", "system", alloc);
if (caps_.requires_typed_content) {
rapidjson::Value content_array_sys(rapidjson::kArrayType);
rapidjson::Value content_item_sys(rapidjson::kObjectType);
content_item_sys.AddMember("type", "text", alloc);
content_item_sys.AddMember("text", rapidjson::StringRef(sys_needle.c_str()), alloc);
content_array_sys.PushBack(content_item_sys, alloc);
needle_system_msg.AddMember("content", content_array_sys, alloc);
} else {
needle_system_msg.AddMember("content", rapidjson::StringRef(sys_needle.c_str()), alloc);
}
// caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
rapidjson::Value messages_for_sys_role_test(rapidjson::kArrayType);
rapidjson::Value needle_system_msg_copy; needle_system_msg_copy.CopyFrom(needle_system_msg, alloc);
rapidjson::Value dummy_user_msg_copy2; dummy_user_msg_copy2.CopyFrom(dummy_user_msg, alloc);
messages_for_sys_role_test.PushBack(needle_system_msg_copy, alloc);
messages_for_sys_role_test.PushBack(dummy_user_msg_copy2, alloc);
rapidjson::Value no_tools_copy3; no_tools_copy3.CopyFrom(no_tools, alloc);
caps_.supports_system_role = contains(try_raw_render(messages_for_sys_role_test, no_tools_copy3, false, alloc), sys_needle);
// auto out = try_raw_render(json::array({dummy_user_msg}), json::array({...}), false);
rapidjson::Value messages_for_tools_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy3; dummy_user_msg_copy3.CopyFrom(dummy_user_msg, alloc);
messages_for_tools_test.PushBack(dummy_user_msg_copy3, alloc);
rapidjson::Value tools_for_test(rapidjson::kArrayType);
rapidjson::Value tool_def(rapidjson::kObjectType);
tool_def.AddMember("name", "some_tool", alloc);
tool_def.AddMember("type", "function", alloc);
rapidjson::Value function_def(rapidjson::kObjectType);
function_def.AddMember("name", "some_tool", alloc);
function_def.AddMember("description", "Some tool.", alloc);
rapidjson::Value params_def(rapidjson::kObjectType);
params_def.AddMember("type", "object", alloc);
rapidjson::Value props_def(rapidjson::kObjectType);
rapidjson::Value arg_def(rapidjson::kObjectType);
arg_def.AddMember("type", "string", alloc);
arg_def.AddMember("description", "Some argument.", alloc);
props_def.AddMember("arg", arg_def, alloc);
params_def.AddMember("properties", props_def, alloc);
rapidjson::Value required_arr(rapidjson::kArrayType);
required_arr.PushBack("arg", alloc);
params_def.AddMember("required", required_arr, alloc);
function_def.AddMember("parameters", params_def, alloc);
tool_def.AddMember("function", function_def, alloc);
tools_for_test.PushBack(tool_def, alloc);
std::string out_tools_test = try_raw_render(tools_for_test, tools_for_test, false, alloc);
caps_.supports_tools = contains(out_tools_test, "some_tool");
// auto make_tool_calls_msg = [&](const json & tool_calls) { ... }
auto make_tool_calls_msg_rj = [&](rapidjson::Value& tool_calls_val, rapidjson::Document::AllocatorType& allocator_func) {
rapidjson::Value msg(rapidjson::kObjectType);
msg.AddMember("role", "assistant", allocator_func);
msg.AddMember("content", rapidjson::Value(rapidjson::kNullType), allocator_func);
msg.AddMember("tool_calls", tool_calls_val, allocator_func); // tool_calls_val is already using alloc from caller
return msg;
};
// auto make_tool_call = [](const std::string & tool_name, const json & arguments) { ... }
auto make_tool_call_rj = [&](const char* tool_name_str, rapidjson::Value& arguments_val, rapidjson::Document::AllocatorType& allocator_func) {
rapidjson::Value tc(rapidjson::kObjectType);
tc.AddMember("id", "call_1___", allocator_func);
tc.AddMember("type", "function", allocator_func);
rapidjson::Value func(rapidjson::kObjectType);
func.AddMember("arguments", arguments_val, allocator_func); // arguments_val is already using alloc from caller
func.AddMember("name", rapidjson::StringRef(tool_name_str), allocator_func);
tc.AddMember("function", func, allocator_func);
return tc;
};
// const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
rapidjson::Value dummy_args_obj_rj(rapidjson::kObjectType);
dummy_args_obj_rj.AddMember("argument_needle", "print('Hello, World!')", alloc);
// Convert dummy_args_obj_rj to string for the first test
rapidjson::StringBuffer buffer_args_str;
rapidjson::Writer<rapidjson::StringBuffer> writer_args_str(buffer_args_str);
dummy_args_obj_rj.Accept(writer_args_str);
std::string dummy_args_obj_as_string = buffer_args_str.GetString();
rapidjson::Value dummy_args_str_val(dummy_args_obj_as_string.c_str(), alloc);
// out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})) }), {}, false);
rapidjson::Value messages_for_tool_call_str_args_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy4; dummy_user_msg_copy4.CopyFrom(dummy_user_msg, alloc);
messages_for_tool_call_str_args_test.PushBack(dummy_user_msg_copy4, alloc);
rapidjson::Value tool_calls_array1(rapidjson::kArrayType);
rapidjson::Value tc1_args_str; tc1_args_str.CopyFrom(dummy_args_str_val, alloc); // Already a string value
std::string ipython = "ipython";
tool_calls_array1.PushBack(make_tool_call_rj(ipython.c_str(), tc1_args_str, alloc), alloc);
rapidjson::Value tool_calls_msg1 = make_tool_calls_msg_rj(tool_calls_array1, alloc);
messages_for_tool_call_str_args_test.PushBack(tool_calls_msg1, alloc);
rapidjson::Value no_tools_copy4; no_tools_copy4.CopyFrom(no_tools, alloc);
std::string out_tool_call_str_args = try_raw_render(messages_for_tool_call_str_args_test, no_tools_copy4, false, alloc);
bool tool_call_renders_str_arguments = contains(out_tool_call_str_args, "\"argument_needle\":") || contains(out_tool_call_str_args, "'argument_needle':");
// out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})) }), {}, false);
rapidjson::Value messages_for_tool_call_obj_args_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy5; dummy_user_msg_copy5.CopyFrom(dummy_user_msg, alloc);
messages_for_tool_call_obj_args_test.PushBack(dummy_user_msg_copy5, alloc);
rapidjson::Value tool_calls_array2(rapidjson::kArrayType);
rapidjson::Value tc1_args_obj; tc1_args_obj.CopyFrom(dummy_args_obj_rj, alloc);
tool_calls_array2.PushBack(make_tool_call_rj("ipython", tc1_args_obj, alloc), alloc);
rapidjson::Value tool_calls_msg2 = make_tool_calls_msg_rj(tool_calls_array2, alloc);
messages_for_tool_call_obj_args_test.PushBack(tool_calls_msg2, alloc);
rapidjson::Value no_tools_copy5; no_tools_copy5.CopyFrom(no_tools, alloc);
std::string out_tool_call_obj_args = try_raw_render(messages_for_tool_call_obj_args_test, no_tools_copy5, false, alloc);
bool tool_call_renders_obj_arguments = contains(out_tool_call_obj_args, "\"argument_needle\":") || contains(out_tool_call_obj_args, "'argument_needle':");
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
// auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
rapidjson::Value messages_for_empty_content_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy6; dummy_user_msg_copy6.CopyFrom(dummy_user_msg, alloc);
messages_for_empty_content_test.PushBack(dummy_user_msg_copy6, alloc);
rapidjson::Value assistant_msg_empty_content(rapidjson::kObjectType);
assistant_msg_empty_content.AddMember("role", "assistant", alloc);
assistant_msg_empty_content.AddMember("content", "", alloc);
messages_for_empty_content_test.PushBack(assistant_msg_empty_content, alloc);
rapidjson::Value no_tools_copy6; no_tools_copy6.CopyFrom(no_tools, alloc);
std::string out_empty_content = try_raw_render(messages_for_empty_content_test, no_tools_copy6, false, alloc);
// auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
rapidjson::Value messages_for_null_content_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy7; dummy_user_msg_copy7.CopyFrom(dummy_user_msg, alloc);
messages_for_null_content_test.PushBack(dummy_user_msg_copy7, alloc);
rapidjson::Value assistant_msg_null_content(rapidjson::kObjectType);
assistant_msg_null_content.AddMember("role", "assistant", alloc);
assistant_msg_null_content.AddMember("content", rapidjson::Value(rapidjson::kNullType), alloc);
messages_for_null_content_test.PushBack(assistant_msg_null_content, alloc);
rapidjson::Value no_tools_copy7; no_tools_copy7.CopyFrom(no_tools, alloc);
std::string out_null_content = try_raw_render(messages_for_null_content_test, no_tools_copy7, false, alloc);
caps_.requires_non_null_content = contains(out_empty_content, user_needle) && !contains(out_null_content, user_needle);
if (caps_.supports_tool_calls) {
// auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
rapidjson::Value dummy_args_for_parallel_test;
if (caps_.requires_object_arguments) {
dummy_args_for_parallel_test.CopyFrom(dummy_args_obj_rj, alloc);
} else {
// This was already created: dummy_args_str_val (string version of dummy_args_obj_rj)
dummy_args_for_parallel_test.CopyFrom(dummy_args_str_val, alloc);
}
// auto tc1 = make_tool_call("test_tool1", dummy_args);
// auto tc2 = make_tool_call("test_tool2", dummy_args);
rapidjson::Value dummy_args_tc1; dummy_args_tc1.CopyFrom(dummy_args_for_parallel_test, alloc);
rapidjson::Value tc1 = make_tool_call_rj("test_tool1", dummy_args_tc1, alloc);
rapidjson::Value dummy_args_tc2; dummy_args_tc2.CopyFrom(dummy_args_for_parallel_test, alloc);
rapidjson::Value tc2 = make_tool_call_rj("test_tool2", dummy_args_tc2, alloc);
// auto out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1, tc2})) }), {}, false);
rapidjson::Value messages_for_parallel_calls_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy8; dummy_user_msg_copy8.CopyFrom(dummy_user_msg, alloc);
messages_for_parallel_calls_test.PushBack(dummy_user_msg_copy8, alloc);
rapidjson::Value tool_calls_array_parallel(rapidjson::kArrayType);
tool_calls_array_parallel.PushBack(tc1, alloc); // tc1, tc2 are already using alloc
tool_calls_array_parallel.PushBack(tc2, alloc);
rapidjson::Value tool_calls_msg_parallel = make_tool_calls_msg_rj(tool_calls_array_parallel, alloc);
messages_for_parallel_calls_test.PushBack(tool_calls_msg_parallel, alloc);
rapidjson::Value no_tools_copy8; no_tools_copy8.CopyFrom(no_tools, alloc);
std::string out_parallel_calls = try_raw_render(messages_for_parallel_calls_test, no_tools_copy8, false, alloc);
caps_.supports_parallel_tool_calls = contains(out_parallel_calls, "test_tool1") && contains(out_parallel_calls, "test_tool2");
// Need to re-create tc1 as it was moved into tool_calls_array_parallel
rapidjson::Value dummy_args_tc1_resp; dummy_args_tc1_resp.CopyFrom(dummy_args_for_parallel_test, alloc);
rapidjson::Value tc1_resp = make_tool_call_rj("test_tool1", dummy_args_tc1_resp, alloc);
// out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1})), { ...tool response... } }), {}, false);
rapidjson::Value messages_for_tool_response_test(rapidjson::kArrayType);
rapidjson::Value dummy_user_msg_copy9; dummy_user_msg_copy9.CopyFrom(dummy_user_msg, alloc);
messages_for_tool_response_test.PushBack(dummy_user_msg_copy9, alloc);
rapidjson::Value tool_calls_array_resp(rapidjson::kArrayType);
tool_calls_array_resp.PushBack(tc1_resp, alloc);
rapidjson::Value tool_calls_msg_resp = make_tool_calls_msg_rj(tool_calls_array_resp, alloc);
messages_for_tool_response_test.PushBack(tool_calls_msg_resp, alloc);
rapidjson::Value tool_response_msg(rapidjson::kObjectType);
tool_response_msg.AddMember("role", "tool", alloc);
tool_response_msg.AddMember("name", "test_tool1", alloc);
tool_response_msg.AddMember("content", "Some response!", alloc);
tool_response_msg.AddMember("tool_call_id", "call_911_", alloc);
messages_for_tool_response_test.PushBack(tool_response_msg, alloc);
rapidjson::Value no_tools_copy9; no_tools_copy9.CopyFrom(no_tools, alloc);
std::string out_tool_response = try_raw_render(messages_for_tool_response_test, no_tools_copy9, false, alloc);
caps_.supports_tool_responses = contains(out_tool_response, "Some response!");
caps_.supports_tool_call_id = contains(out_tool_response, "call_911_");
}
if (!caps_.supports_tools) {
// const json user_msg { {"role", "user"}, {"content", "Hey"} };
rapidjson::Value user_msg_infer(rapidjson::kObjectType);
user_msg_infer.AddMember("role", "user", alloc);
user_msg_infer.AddMember("content", "Hey", alloc);
// const json args { {"arg1", "some_value"} };
rapidjson::Value args_infer(rapidjson::kObjectType);
args_infer.AddMember("arg1", "some_value", alloc);
// const json tool_call_msg { ... }
rapidjson::Value tool_call_msg_infer(rapidjson::kObjectType);
tool_call_msg_infer.AddMember("role", "assistant", alloc);
tool_call_msg_infer.AddMember("content", rapidjson::Value(rapidjson::kNullType), alloc);
rapidjson::Value tool_calls_array_infer(rapidjson::kArrayType);
rapidjson::Value tool_call_item_infer(rapidjson::kObjectType);
tool_call_item_infer.AddMember("id", "call_1___", alloc);
tool_call_item_infer.AddMember("type", "function", alloc);
rapidjson::Value function_item_infer(rapidjson::kObjectType);
function_item_infer.AddMember("name", "tool_name", alloc);
rapidjson::Value arguments_infer;
if (caps_.requires_object_arguments) {
arguments_infer.CopyFrom(args_infer, alloc);
} else {
// This requires minja::Value::dump which itself uses nlohmann::json.
// This part needs a temporary nlohmann::json to dump, or reimplement dump logic for rapidjson.
// For now, let's assume minja::Value can give us a string that rapidjson can parse,
// or we construct the string directly.
// minja::Value(args).dump(-1, /* to_json= */ true)
// This is a major dependency. For now, I'll create a simple string version.
rapidjson::StringBuffer buffer_args_infer_str;
rapidjson::Writer<rapidjson::StringBuffer> writer_args_infer_str(buffer_args_infer_str);
args_infer.Accept(writer_args_infer_str);
arguments_infer.SetString(buffer_args_infer_str.GetString(), alloc);
}
function_item_infer.AddMember("arguments", arguments_infer, alloc);
tool_call_item_infer.AddMember("function", function_item_infer, alloc);
tool_calls_array_infer.PushBack(tool_call_item_infer, alloc);
tool_call_msg_infer.AddMember("tool_calls", tool_calls_array_infer, alloc);
std::string prefix_str, full_str;
{
chat_template_inputs inputs_prefix;
inputs_prefix.allocator_for_inputs = &alloc;
inputs_prefix.messages.SetArray();
rapidjson::Value user_msg_infer_copy1; user_msg_infer_copy1.CopyFrom(user_msg_infer, alloc);
inputs_prefix.messages.PushBack(user_msg_infer_copy1, alloc);
inputs_prefix.add_generation_prompt = true;
// inputs.tools is already kNullType by default in chat_template_inputs constructor
prefix_str = apply(inputs_prefix);
}
{
chat_template_inputs inputs_full;
inputs_full.allocator_for_inputs = &alloc;
inputs_full.messages.SetArray();
rapidjson::Value user_msg_infer_copy2; user_msg_infer_copy2.CopyFrom(user_msg_infer, alloc);
inputs_full.messages.PushBack(user_msg_infer_copy2, alloc);
rapidjson::Value tool_call_msg_infer_copy; tool_call_msg_infer_copy.CopyFrom(tool_call_msg_infer, alloc);
inputs_full.messages.PushBack(tool_call_msg_infer_copy, alloc);
inputs_full.add_generation_prompt = false;
// inputs.tools is already kNullType by default
full_str = apply(inputs_full);
}
// ... rest of the logic for tool_call_example_ using prefix_str and full_str
// This part seems okay to remain as string manipulation
auto eos_pos_last = full_str.rfind(eos_token_);
if (eos_pos_last == prefix_str.size() - eos_token_.size() ||
(full_str[full_str.size() - 1] == '\n' && (eos_pos_last == full_str.size() - eos_token_.size() - 1))) {
full_str = full_str.substr(0, eos_pos_last);
}
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix_str.size() && i < full_str.size(); ++i) {
if (prefix_str[i] != full_str[i]) {
break;
}
if (prefix_str[i] == '<') {
continue;
}
common_prefix_length = i + 1;
}
auto example = full_str.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
}
// Ensure d_render_test is cleared if it were a member, but it's local.
#endif
}
// Render the template with polyfills disabled, used by the capability
// probes. Inputs are deep-copied into a chat_template_inputs whose values
// all live in the caller-supplied allocator; a fixed epoch timestamp keeps
// probe renders deterministic.
std::string chat_template::try_raw_render(
    rapidjson::Value& messages,          // message array (deep-copied, not mutated)
    rapidjson::Value& tools,             // tool array or null (deep-copied)
    bool add_generation_prompt,
    rapidjson::Document::AllocatorType& allocator,  // owns every copy made below
    rapidjson::Value extra_context) const           // optional object; null => {}
{
    // Probe renders must not trigger polyfills: we want to observe what the
    // raw template does with the given structures.
    chat_template_options render_opts;
    render_opts.apply_polyfills = false;

    chat_template_inputs render_inputs;
    render_inputs.allocator_for_inputs = &allocator;
    render_inputs.add_generation_prompt = add_generation_prompt;
    render_inputs.messages.CopyFrom(messages, allocator);
    render_inputs.tools.CopyFrom(tools, allocator);
    if (extra_context.IsNull()) {
        // Normalize "no extra context" to an empty object.
        render_inputs.extra_context.SetObject();
    } else {
        render_inputs.extra_context.CopyFrom(extra_context, allocator);
    }
    // Fixed date so probe output (e.g. strftime_now) is reproducible.
    render_inputs.now = std::chrono::system_clock::from_time_t(0);

    return apply(render_inputs, render_opts);
}
// Render the chat template for the given messages/tools.
//
// When the template lacks a capability (probed into caps_), the
// corresponding polyfill rewrites the messages into a shape the template
// can render: tools become a system prompt, tool calls/responses become
// JSON-in-content, system messages are merged into user messages, plain
// string content becomes typed content, etc.
//
// All rapidjson::Values built here are allocated from working_doc's
// allocator and die with it when this call returns; the rendered string
// is the only thing that escapes.
std::string chat_template::apply(
    chat_template_inputs & inputs,
    const chat_template_options & opts) const {
    AUTOTIME;
    // Working document for this apply call; every new Value uses its allocator.
    Document working_doc;
    rapidjson::Document::AllocatorType& allocator = working_doc.GetAllocator();

    rapidjson::Value actual_messages(rapidjson::kArrayType);

    // One scan of the input to decide which polyfills are relevant.
    auto has_tools = inputs.tools.IsArray() && !inputs.tools.Empty();
    auto has_tool_calls = false;
    auto has_tool_responses = false;
    auto has_string_content = false;
    if (inputs.messages.IsArray()) {
        for (const auto & message_val : inputs.messages.GetArray()) {
            if (message_val.IsObject()) {
                if (message_val.HasMember("tool_calls") && !message_val["tool_calls"].IsNull()) {
                    has_tool_calls = true;
                }
                if (message_val.HasMember("role") && message_val["role"].IsString() &&
                    strcmp(message_val["role"].GetString(), "tool") == 0) {
                    has_tool_responses = true;
                }
                if (message_val.HasMember("content") && message_val["content"].IsString()) {
                    has_string_content = true;
                }
            }
        }
    }

    auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
    auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
    auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
    auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
    auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
    auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
    auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;

    auto needs_polyfills = opts.apply_polyfills && (false
        || polyfill_system_role
        || polyfill_tools
        || polyfill_tool_calls
        || polyfill_tool_responses
        || polyfill_object_arguments
        || polyfill_typed_content
    );

    if (needs_polyfills) {
        // Appends msg (deep-copied into 'allocator') to actual_messages,
        // converting plain string content into typed content if required.
        auto add_message = [&](const rapidjson::Value & msg_const) {
            rapidjson::Value msg;
            msg.CopyFrom(msg_const, allocator);
            if (polyfill_typed_content && msg.IsObject() && msg.HasMember("content") &&
                !msg["content"].IsNull() && msg["content"].IsString()) {
                rapidjson::Value new_msg(rapidjson::kObjectType);
                new_msg.AddMember("role", rapidjson::Value(msg["role"], allocator), allocator);
                rapidjson::Value content_array_typed(rapidjson::kArrayType);
                rapidjson::Value content_item_typed(rapidjson::kObjectType);
                content_item_typed.AddMember("type", "text", allocator);
                rapidjson::Value text_val(msg["content"].GetString(), allocator);
                content_item_typed.AddMember("text", text_val, allocator);
                content_array_typed.PushBack(content_item_typed, allocator);
                new_msg.AddMember("content", content_array_typed, allocator);
                actual_messages.PushBack(new_msg, allocator);
            } else {
                actual_messages.PushBack(msg, allocator);
            }
        };

        // System text accumulated while polyfilling the system role.
        std::string pending_system;
        auto flush_sys = [&]() {
            if (!pending_system.empty()) {
                rapidjson::Value sys_as_user_msg(rapidjson::kObjectType);
                sys_as_user_msg.AddMember("role", "user", allocator);
                // BUGFIX: deep-copy the string into the allocator. The previous
                // rapidjson::StringRef(pending_system.c_str()) kept a raw pointer
                // into pending_system, and CopyFrom does not copy const-ref
                // strings by default, so clearing pending_system below left the
                // stored message content dangling.
                sys_as_user_msg.AddMember("content",
                                          rapidjson::Value(pending_system.c_str(), allocator),
                                          allocator);
                add_message(sys_as_user_msg);
                pending_system.clear();
            }
        };

        rapidjson::Value adjusted_messages_val(rapidjson::kArrayType);
        if (polyfill_tools) {
            // Describe the tools in a synthetic system prompt.
            rapidjson::StringBuffer tools_buffer;
            rapidjson::PrettyWriter<rapidjson::StringBuffer> tools_writer(tools_buffer);
            tools_writer.SetIndent(' ', 2);
            inputs.tools.Accept(tools_writer);
            std::string tools_str_prompt = tools_buffer.GetString();

            std::string system_prompt_str =
                "You can call any of the following tools to satisfy the user's requests: " + tools_str_prompt +
                (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n");

            rapidjson::Value messages_copy_for_add_system;
            messages_copy_for_add_system.CopyFrom(inputs.messages, allocator);
            adjusted_messages_val = add_system(messages_copy_for_add_system, system_prompt_str, allocator);
        } else {
            adjusted_messages_val.CopyFrom(inputs.messages, allocator);
        }

        if (adjusted_messages_val.IsArray()){
            for (auto & message_val_mut : adjusted_messages_val.GetArray()) {
                // Mutable per-iteration copy so polyfills can rewrite it freely.
                rapidjson::Value message;
                message.CopyFrom(message_val_mut, allocator);
                if (!message.IsObject() || !message.HasMember("role") || !message.HasMember("content")) {
                    fprintf(stderr, "message must have 'role' and 'content' fields: %s\n", valueToString(message).c_str());
                    continue;
                }
                const char* role_cstr = message["role"].GetString();
                std::string role = role_cstr;

                if (message.HasMember("tool_calls")) {
                    if (polyfill_object_arguments || polyfill_tool_calls) {
                        // Templates requiring object arguments get stringified
                        // arguments parsed back into JSON objects.
                        if (message["tool_calls"].IsArray()) {
                            for (auto & tool_call_val : message["tool_calls"].GetArray()) {
                                if (tool_call_val.IsObject() && tool_call_val.HasMember("type") && tool_call_val["type"] == "function") {
                                    if (tool_call_val.HasMember("function") && tool_call_val["function"].IsObject()) {
                                        auto& function_val = tool_call_val["function"];
                                        if (function_val.HasMember("arguments") && function_val["arguments"].IsString()) {
                                            std::string args_str = function_val["arguments"].GetString();
                                            Document args_doc;
                                            if (!args_doc.Parse(args_str.c_str()).HasParseError()) {
                                                rapidjson::Value new_args_val;
                                                new_args_val.CopyFrom(args_doc, allocator);
                                                // Swap avoids an extra deep copy.
                                                function_val["arguments"].Swap(new_args_val);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    if (polyfill_tool_calls) {
                        // Fold the tool calls into the message content as JSON text.
                        rapidjson::Value content_val; content_val.CopyFrom(message["content"], allocator);
                        rapidjson::Value tool_calls_payload(rapidjson::kArrayType);
                        if (message["tool_calls"].IsArray()) {
                            for (const auto & tool_call_val_const : message["tool_calls"].GetArray()) {
                                if (tool_call_val_const.IsObject() && tool_call_val_const.HasMember("type") && tool_call_val_const["type"] == "function") {
                                    const auto& function_val_const = tool_call_val_const["function"];
                                    rapidjson::Value tc_item(rapidjson::kObjectType);
                                    tc_item.AddMember("name", rapidjson::Value(function_val_const["name"], allocator), allocator);
                                    // Arguments are objects here if polyfill_object_arguments ran.
                                    tc_item.AddMember("arguments", rapidjson::Value(function_val_const["arguments"], allocator), allocator);
                                    if (tool_call_val_const.HasMember("id")) {
                                        tc_item.AddMember("id", rapidjson::Value(tool_call_val_const["id"], allocator), allocator);
                                    }
                                    tool_calls_payload.PushBack(tc_item, allocator);
                                }
                            }
                        }
                        rapidjson::Value obj_for_content(rapidjson::kObjectType);
                        obj_for_content.AddMember("tool_calls", tool_calls_payload, allocator);
                        if (!content_val.IsNull() && !(content_val.IsString() && strlen(content_val.GetString()) == 0)) {
                            obj_for_content.AddMember("content", content_val, allocator);
                        }
                        rapidjson::StringBuffer s_buffer;
                        rapidjson::PrettyWriter<rapidjson::StringBuffer> writer_obj(s_buffer);
                        writer_obj.SetIndent(' ', 2);
                        obj_for_content.Accept(writer_obj);
                        message["content"].SetString(s_buffer.GetString(), allocator);
                        message.RemoveMember("tool_calls");
                    }
                }
                if (polyfill_tool_responses && role == "tool") {
                    // Re-emit the tool response as a user message carrying a
                    // {"tool_response": {...}} JSON payload.
                    message["role"].SetString("user", allocator);
                    rapidjson::Value tool_response_obj(rapidjson::kObjectType);
                    rapidjson::Value tool_response_inner_obj(rapidjson::kObjectType);
                    if (message.HasMember("name")) {
                        tool_response_inner_obj.AddMember("tool", rapidjson::Value(message["name"], allocator), allocator);
                    }
                    // message["content"] is guaranteed to exist by the check above.
                    tool_response_inner_obj.AddMember("content", rapidjson::Value(message["content"], allocator), allocator);
                    if (message.HasMember("tool_call_id")) {
                        tool_response_inner_obj.AddMember("tool_call_id", rapidjson::Value(message["tool_call_id"], allocator), allocator);
                    }
                    tool_response_obj.AddMember("tool_response", tool_response_inner_obj, allocator);
                    rapidjson::StringBuffer s_buffer_resp;
                    rapidjson::PrettyWriter<rapidjson::StringBuffer> writer_resp(s_buffer_resp);
                    writer_resp.SetIndent(' ',2);
                    tool_response_obj.Accept(writer_resp);
                    message["content"].SetString(s_buffer_resp.GetString(), allocator);
                    if (message.HasMember("name")) message.RemoveMember("name");
                    if (message.HasMember("tool_call_id")) message.RemoveMember("tool_call_id");
                }
                if (!message["content"].IsNull() && polyfill_system_role) {
                    // Content may be non-string (e.g. typed content); stringify
                    // it so it can be accumulated into pending_system.
                    std::string content_str;
                    if (message["content"].IsString()){
                        content_str = message["content"].GetString();
                    } else {
                        rapidjson::StringBuffer temp_s_buffer;
                        rapidjson::Writer<rapidjson::StringBuffer> temp_writer(temp_s_buffer);
                        message["content"].Accept(temp_writer);
                        content_str = temp_s_buffer.GetString();
                    }
                    if (role == "system") {
                        // System messages are absorbed into pending_system and
                        // not emitted directly.
                        if (!pending_system.empty()) pending_system += "\n";
                        pending_system += content_str;
                        continue;
                    } else {
                        if (role == "user") {
                            // Prepend any accumulated system text to the next user turn.
                            if (!pending_system.empty()) {
                                std::string new_content = pending_system + (content_str.empty() ? "" : "\n" + content_str);
                                message["content"].SetString(new_content.c_str(), allocator);
                                pending_system.clear();
                            }
                        } else { // assistant, or tool already transformed to user
                            flush_sys();
                        }
                    }
                }
                add_message(message);
            }
        }
        flush_sys();
    } else { // no polyfills needed
        actual_messages.CopyFrom(inputs.messages, allocator);
    }

    // Build the root render context: {messages, add_generation_prompt}.
    rapidjson::Value context_data_val(rapidjson::kObjectType);
    context_data_val.AddMember("messages", actual_messages, allocator);
    context_data_val.AddMember("add_generation_prompt", inputs.add_generation_prompt, allocator);
    // Create the context directly from the data (the previous code built a
    // throwaway Context::make(nullptr) that was immediately replaced).
    auto context = minja::Context::make(minja::Value(context_data_val));

    context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
    context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
    if (opts.define_strftime_now) {
        auto time_now_capture = inputs.now; // capture a stable timestamp for the lambda
        context->set("strftime_now", minja::Value::callable([time_now_capture](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
            args.expectArgs("strftime_now", {1, 1}, {0, 0});
            auto format = args.args[0].get<std::string>();
            auto time_point = std::chrono::system_clock::to_time_t(time_now_capture);
            auto local_time = *std::localtime(&time_point);
            std::ostringstream ss;
            ss << std::put_time(&local_time, format.c_str());
            return ss.str();
        }));
    }
    if (!inputs.tools.IsNull()) {
        context->set("tools", minja::Value(inputs.tools));
    }
    if (!inputs.extra_context.IsNull() && inputs.extra_context.IsObject()) {
        for (auto & kv : inputs.extra_context.GetObject()) {
            context->set(kv.name.GetString(), minja::Value(kv.value));
        }
    }

    auto ret = template_root_->render(context);
    return ret;
}
// Return a copy of messages_const with system_prompt merged in: appended to
// an existing leading system message, or inserted as a new first message.
// The returned Value (and everything it owns) is allocated from 'allocator'.
rapidjson::Value chat_template::add_system(
    const rapidjson::Value & messages_const, // input messages (const ref)
    const std::string & system_prompt,
    rapidjson::Document::AllocatorType& allocator) {
    rapidjson::Value messages_with_system(rapidjson::kArrayType);
    messages_with_system.CopyFrom(messages_const, allocator); // deep copy so it is modifiable
    if (!messages_with_system.Empty() && messages_with_system[0].IsObject() &&
        messages_with_system[0].HasMember("role") && messages_with_system[0]["role"] == "system") {
        std::string existing_system_content_str;
        if (messages_with_system[0].HasMember("content") && messages_with_system[0]["content"].IsString()) {
            existing_system_content_str = messages_with_system[0]["content"].GetString();
        }
        std::string new_content_str = existing_system_content_str + "\n\n" + system_prompt;
        if (messages_with_system[0].HasMember("content")) {
            messages_with_system[0]["content"].SetString(new_content_str.c_str(), allocator);
        } else {
            // BUGFIX: the previous code indexed a missing "content" member
            // (rapidjson assert / UB); add it instead.
            messages_with_system[0].AddMember("content",
                                              rapidjson::Value(new_content_str.c_str(), allocator),
                                              allocator);
        }
    } else {
        rapidjson::Value new_system_msg(rapidjson::kObjectType);
        new_system_msg.AddMember("role", "system", allocator);
        // BUGFIX: deep-copy the prompt into the allocator. StringRef kept a raw
        // pointer into the caller's string; at the call site in apply() that
        // string is block-scoped and is destroyed before the template is
        // rendered, leaving the message content dangling.
        new_system_msg.AddMember("content",
                                 rapidjson::Value(system_prompt.c_str(), allocator),
                                 allocator);
        // Insert at the beginning by rebuilding the array.
        rapidjson::Value temp_array(rapidjson::kArrayType);
        temp_array.PushBack(new_system_msg, allocator);
        for (auto& el : messages_with_system.GetArray()) {
            rapidjson::Value el_copy;
            el_copy.CopyFrom(el, allocator);
            temp_array.PushBack(el_copy, allocator);
        }
        messages_with_system.Swap(temp_array);
    }
    return messages_with_system; // allocated with 'allocator'
}
};
#endif

View File

@ -0,0 +1,112 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <ctime>
#include <iomanip>
#include <memory>
#include <string>
#include <vector>
#include "rapidjson/document.h"
// Forward declaration for Value used in Minja
namespace minja {
class Value;
class TemplateNode;
}
using Document = rapidjson::Document;
// Note: rapidjson::Value is the type for all JSON values (objects, arrays, strings, numbers, booleans, null).
// rapidjson::Document inherits from rapidjson::Value and holds the memory allocations for the DOM.
// We will use rapidjson::Value where nlohmann::json was used for individual values,
// and rapidjson::Document where a new JSON structure was being parsed or built.
namespace minja {
// Capability flags describing what a given chat template natively supports.
// When a supports_* flag is false (or a requires_* flag is true), the
// corresponding polyfill selected in chat_template_options is needed to
// emulate the feature before rendering.
struct chat_template_caps {
bool supports_tools = false;              // template accepts a `tools` input
bool supports_tool_calls = false;         // assistant messages may carry tool_calls
bool supports_tool_responses = false;     // messages with role "tool" are understood
bool supports_system_role = false;        // a leading "system" message is honored
bool supports_parallel_tool_calls = false; // multiple tool_calls in one message
bool supports_tool_call_id = false;       // tool responses can reference a call id
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments = false;
// CohereForAI/c4ai-command-r-plus simple variant
bool requires_non_null_content = false;
// MiniMaxAI/MiniMax-Text-01 special
bool requires_typed_content = false;
};
// One invocation's worth of inputs for chat_template::apply().
// Each rapidjson::Document member both holds the JSON value and owns the
// backing DOM storage for it.
struct chat_template_inputs {
rapidjson::Document messages; // Should be an array
rapidjson::Document tools; // Should be an array or null
bool add_generation_prompt = true; // presumably appends the assistant-turn prompt header — confirm against apply()
rapidjson::Document extra_context; // Should be an object or null
// Timestamp exposed to the template (e.g. via the strftime_now helper
// enabled by chat_template_options::define_strftime_now); defaults to "now".
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
// Non-owning pointer to the allocator used while building these inputs.
// To be set when creating inputs; the caller keeps the allocator alive.
rapidjson::Document::AllocatorType* allocator_for_inputs = nullptr;
// Default constructor to initialize Value members:
// messages starts as an empty array, tools and extra_context start as null,
// matching the "Should be" notes above.
chat_template_inputs() : messages(rapidjson::kArrayType), tools(rapidjson::kNullType), extra_context(rapidjson::kNullType) {}
};
// Per-call rendering options for chat_template::apply().
// apply_polyfills is the master switch; the individual polyfill_* flags select
// which features get emulated when the template itself lacks them (see the
// matching flags in chat_template_caps).
struct chat_template_options {
bool apply_polyfills = true;
bool use_bos_token = true; // presumably prepends bos_token to the result — confirm against apply()
bool use_eos_token = true; // presumably appends eos_token to the result — confirm against apply()
bool define_strftime_now = true; // expose a strftime_now() helper to the template
bool polyfill_tools = true;
bool polyfill_tool_call_examples = true;
bool polyfill_tool_calls = true;
bool polyfill_tool_responses = true;
bool polyfill_system_role = true;
bool polyfill_object_arguments = true;
bool polyfill_typed_content = true;
};
// Jinja-style chat template engine (minja, adapted from nlohmann::json to
// rapidjson — see the note at the top of this header).  Holds the raw template
// source plus the BOS/EOS tokens and renders OpenAI-style message arrays into
// a single prompt string.
class chat_template {
  private:
chat_template_caps caps_; // capability flags for this template (see chat_template_caps)
std::string source_; // raw Jinja template text
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_; // parsed template AST root
std::string tool_call_example_; // presumably a rendered example tool call used by polyfills — confirm
// Render without polyfills.
// NOTE(review): extra_context is taken by value with a null default; rapidjson
// values are non-copyable, so this relies on rapidjson::Value's move ctor.
std::string try_raw_render(
rapidjson::Value& messages, // Modifying to pass by ref as it might be changed by polyfills later
rapidjson::Value& tools, // Modifying to pass by ref
bool add_generation_prompt,
rapidjson::Document::AllocatorType& allocator, // Added allocator
rapidjson::Value extra_context = rapidjson::Value(rapidjson::kNullType)) const; // Default to null
  public:
MNN_PUBLIC chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token);
// Accessors for the raw template text and its special tokens.
const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
// Capabilities of the template as authored (before any polyfills are applied).
const chat_template_caps & original_caps() const { return caps_; }
// Render inputs into the final prompt string, applying the polyfills selected
// in opts for features the template lacks.
MNN_PUBLIC std::string apply(
chat_template_inputs & inputs,
const chat_template_options & opts = chat_template_options()) const;
// Return a copy of messages_const with system_prompt merged into (or inserted
// as) a leading "system" message.  The result is allocated from `allocator`,
// which must outlive the returned Value.
// NOTE(review): the implementation stores system_prompt via rapidjson::StringRef
// in the insert branch (no copy), so the caller must also keep system_prompt
// alive for the lifetime of the returned Value — confirm or switch to SetString.
static rapidjson::Value add_system(
const rapidjson::Value & messages_const, // input messages (const ref)
const std::string & system_prompt,
rapidjson::Document::AllocatorType& allocator);
};
};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -649,10 +649,13 @@ VARP Omni::gen_position_ids(int seq_len) {
positionIds = _Input({3, seq_len}, NCHW, halide_type_of<int>());
}
auto ptr = positionIds->writeMap<int>();
if (seq_len == 1) {
ptr[0] = mContext->gen_seq_len + mPositionIds.back();
ptr[1] = ptr[0];
ptr[2] = ptr[0];
if (mContext->gen_seq_len > 0) {
for (int i=0; i<seq_len; ++i) {
auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
ptr[i + 0] = pos;
ptr[i + seq_len] = pos;
ptr[i + seq_len * 2] = pos;
}
} else {
for (int i = 0; i < seq_len; i++) {
ptr[i] = mPositionIds.mT[i];
@ -666,23 +669,12 @@ VARP Omni::gen_position_ids(int seq_len) {
return positionIds;
}
Express::VARP Omni::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) {
VARP logits;
auto logitsIndex = _var<int>({-1}, {1});
if (mConfig->all_logits()) {
logitsIndex = _var<int>({0}, {1});
}
std::vector<Express::VARP> outputs;
outputs = mCurrentModules.back()->onForward({hiddenState, mask, inputPos, logitsIndex});
if (outputs.empty()) {
return nullptr;
}
std::vector<Express::VARP> Omni::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) {
auto outputs = Llm::forwardRaw(hiddenState, mask, inputPos);
if (mTalker && outputs.size() > 1) {
mTalker->addTalkerEmbeds(outputs[1]);
}
logits = outputs[0];
mMeta->sync();
return logits;
return outputs;
}
void Omni::response(const std::vector<int>& input_ids, std::ostream* os, const char* end_with, int max_new_tokens) {
@ -733,6 +725,7 @@ void Omni::generateWavform() {
void Talker::load() {
initRuntime();
mSeqLenIndex = 1;
set_config("{\"sampler_type\": \"mixed\", \"temperature\": 0.9, \"topK\": 40, \"topP\": 0.8, \"penalty\": 1.05}");
mSampler.reset(Sampler::createSampler(mContext, mConfig));
mDiskEmbedding.reset(new DiskEmbedding(mConfig, mConfig->talker_embedding_file()));
@ -753,7 +746,9 @@ void Talker::load() {
module_config.shapeMutable = false;
module_config.rearrange = true;
mModules.resize(1);
mModules[0].reset(Module::load({"inputs_embeds", "attention_mask", "position_ids"},
std::vector<std::string> inputNames {"inputs_embeds", "attention_mask", "position_ids", "logits_index"};
mModules[0].reset(Module::load(inputNames,
{"logits"}, mConfig->talker_model().c_str(), mRuntimeManager, &module_config));
// dit
mPreDit.reset(Module::load({"cond", "spk", "code"}, {"code_embeds", "rope", "mask"},
@ -770,24 +765,7 @@ void Talker::load() {
void Talker::generate_init(std::ostream* os, const char* end_with) {
if (!doGenerate()) { return; }
{
mContext->os = os;
if (nullptr != end_with) {
mContext->end_with = end_with;
}
if (!mContext->generate_str.empty()) {
mContext->generate_str.clear();
}
mContext->gen_seq_len = 0;
mContext->prefill_us = 0;
mContext->decode_us = 0;
mContext->current_token = 0;
mContext->all_seq_len = 0;
mContext->history_tokens.clear();
mMeta->remove = mMeta->previous;
mContext->output_tokens.clear();
mCurrentModules = mPrefillModules;
}
Llm::generate_init(os, end_with);
// stream generate init
mTalkerEmbeds.clear();
if (mInitialNoise.empty()) {
@ -947,19 +925,6 @@ int Talker::sample(Express::VARP logits, int offset, int size) {
return token;
}
VARP Talker::forward(VARP input_embeds) {
auto input_shape = input_embeds->getInfo()->dim;
int seq_len = input_shape[1];
mMeta->add = seq_len;
auto attention_mask = gen_attention_mask(seq_len);
auto position_ids = gen_position_ids(seq_len);
auto outputs = mCurrentModules.back()->onForward({input_embeds, attention_mask, position_ids});
mContext->all_seq_len += seq_len;
mContext->gen_seq_len++;
mMeta->sync();
return outputs[0];
}
void Talker::generate() {
if (!doGenerate()) { return; }
mTalkerEmbeds.push_back(mTextEos);
@ -970,11 +935,13 @@ void Talker::generate() {
mContext->prompt_len = input_embeds->getInfo()->dim[1];
MNN::Timer _t;
auto logits = forward(input_embeds);
int token = sample(logits);
mContext->current_token = sample(logits);
mContext->history_tokens.push_back(mContext->current_token);
mContext->output_tokens.push_back(mContext->current_token);
mContext->prefill_us += _t.durationInUs();
_t.reset();
for (int i = 1; i < mMaxNewTokens; i++) {
input_embeds = embedding({token});
input_embeds = embedding({mContext->current_token});
if (i + 1 < mTalkerEmbeds.size()) {
input_embeds = input_embeds + mTalkerEmbeds[i + 1];
} else {
@ -982,8 +949,11 @@ void Talker::generate() {
input_embeds = input_embeds + mTextPad;
}
auto logits = forward(input_embeds);
token = sample(logits);
if (token == 8292 || token == 8294) {
mContext->current_token = sample(logits);
mContext->history_tokens.push_back(mContext->current_token);
mContext->output_tokens.push_back(mContext->current_token);
if (mContext->current_token == 8292 || mContext->current_token == 8294) {
break;
}
}

Some files were not shown because too many files have changed in this diff Show More