diff --git a/docs/compile/engine.md b/docs/compile/engine.md index c2d89540..440e1151 100644 --- a/docs/compile/engine.md +++ b/docs/compile/engine.md @@ -126,7 +126,7 @@ mkdir build && cd build && cmake .. -DCMAKE_OSX_ARCHITECTURES=arm64 && make -j8 - 基于脚本编译:运行脚本并开启`MNN_ARM82`选项 ``` -sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true" +sh package_scripts/ios/buildiOS.sh -DMNN_ARM82=ON ``` ## 鸿蒙(Harmony) diff --git a/docs/start/demo.md b/docs/start/demo.md index b16cb1e9..96997263 100644 --- a/docs/start/demo.md +++ b/docs/start/demo.md @@ -20,9 +20,11 @@ ### 图像实例分割 代码位置:`demo/exec/segment.cpp` -下载 deeplabv3 分割模型并转换到 mnn 模型 +下载 deeplabv3 分割模型 [https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite](https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/deeplabv3_257_mv_gpu.tflite) +使用 [模型转换工具](../tools/convert.md) 转换为 MNN 模型,转换时加上参数 --keepInputFormat=0 【把输入由NHWC转换为NC4HW4布局】 + ```bash ./segment.out model.mnn input.png result.png ``` @@ -95,14 +97,14 @@ flops_info: 568.792175M backend_info: 13 expect 983 output belong to class: 983 -$ python gpu_session_demo.py mobilenet_demo/mobilenet_v1.mnn mobilenet_demo/ILSVRC2012_val_00049999.JPEG +$ python gpu_session_demo.py mobilenet_demo/mobilenet_v1.mnn mobilenet_demo/ILSVRC2012_val_00049999.JPEG Testing gpu model calling method Load Cache file error. MNN use high precision Can't Find type=3 backend, use 0 instead Can't Find type=3 backend, use 0 instead -Run on backendtype: 13 +Run on backendtype: 13 expect 983 output belong to class: 983 @@ -127,7 +129,7 @@ output belong to class: 983 #### mnist 使用mnist数据训练模型,并测试准确率,无需下载资源,用法如下: ```bash -$ pip install mnist +$ pip install mnist $ python train_mnist.py train loss: 2.3346531 train loss: 0.28027835 @@ -161,7 +163,7 @@ AttributeError: module 'MNN.nn' has no attribute 'FixModule' #### module_save 演示了模型权值的存储和加载 ```bash -$ python test_save.py +$ python test_save.py 0.0004 10 ``` @@ -225,4 +227,3 @@ sh ../tools/script/get_model.sh - [视频抠图](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) - [SuperGlue关键点匹配](https://github.com/Hanson0910/MNNSuperGlue) - [OCR](https://github.com/DayBreak-u/chineseocr_lite/tree/onnx/android_projects/OcrLiteAndroidMNN) -- [Bert-VITS2-MNN](https://github.com/Voine/Bert-VITS2-MNN) \ No newline at end of file diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index 0bd0fc37..5027f9f3 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -73,14 +73,17 @@ python llmexport.py \ - 使用`--lm_quant_bit`来制定lm_head层权重的量化bit数,不指定则使用`--quant_bit`的量化bit数 ### 参数 +执行 `python llmexport.py -h` 可查看参数: ``` -usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT] - [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT] - [--mnnconvert MNNCONVERT] +usage: llmexport.py [-h] --path PATH [--type TYPE] [--tokenizer_path TOKENIZER_PATH] [--lora_path LORA_PATH] + [--gptq_path GPTQ_PATH] [--dst_path DST_PATH] [--verbose] [--test TEST] [--export EXPORT] + [--onnx_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] + [--lm_quant_bit LM_QUANT_BIT] [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--sym] [--seperate_embed] + [--lora_split] llm_exporter -options: +optional arguments: -h, --help show this help message and exit --path PATH path(`str` or `os.PathLike`): Can be either: @@ -88,19 +91,30 @@ options: - A path to a *directory* clone from repo like `../chatglm-6b`. --type TYPE type(`str`, *optional*): The pretrain llm model type. + --tokenizer_path TOKENIZER_PATH + tokenizer path, defaut is `None` mean using `--path` value. --lora_path LORA_PATH lora path, defaut is `None` mean not apply lora. + --gptq_path GPTQ_PATH + gptq path, defaut is `None` mean not apply gptq. --dst_path DST_PATH export onnx/mnn model to path, defaut is `./model`. + --verbose Whether or not to print verbose. --test TEST test model inference with query `TEST`. --export EXPORT export model to an onnx/mnn model. + --onnx_slim Whether or not to use onnx-slim. --quant_bit QUANT_BIT mnn quant bit, 4 or 8, default is 4. --quant_block QUANT_BLOCK - mnn quant block, default is 0 mean channle-wise. + mnn quant block, 0 mean channle-wise, default is 128. --lm_quant_bit LM_QUANT_BIT mnn lm_head quant bit, 4 or 8, default is `quant_bit`. --mnnconvert MNNCONVERT local mnnconvert path, if invalid, using pymnn. + --ppl Whether or not to get all logits of input tokens. + --awq Whether or not to use awq quant. + --sym Whether or not to using symmetric quant (without zeropoint), defualt is False. + --seperate_embed For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin. + --lora_split Whether or not export lora split, defualt is False. ``` ### 权重读取 diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index c4a118cb..2a8c666f 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -75,7 +75,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #define STR_IMP(x) #x #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 3 -#define MNN_VERSION_MINOR 1 -#define MNN_VERSION_PATCH 4 +#define MNN_VERSION_MINOR 2 +#define MNN_VERSION_PATCH 0 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */ diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index b7cc7e0a..07988562 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -99,14 +99,22 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_ template static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) { if (seq_len == 1 || mask == nullptr) { - for (int i = 0; i < seq_len * kv_seq_len; i++) { + for (int i = 0; i < kv_seq_len; i++) { unpack_qk[i] = unpack_qk[i] * mScale; } } else if (mask->getType() == halide_type_of()) { // float mask T* fpmask_ptr = mask->host(); - for (int i = 0; i < seq_len * kv_seq_len; i++) { - unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i]; + int offset = kv_seq_len-seq_len; + for (int i=0; i& inputs, const std: int seq_len = query->length(1); if (inputs.size() > 3) { mask = inputs[3]; - MNN_ASSERT(seq_len == mask->length(2)); } int tileCount = UP_DIV(mNumHead, mThreadNum); int group_size = mNumHead / mKvNumHead; diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 4180c94b..0cc42240 100644 --- a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -594,7 +594,7 @@ ErrorCode CPURaster::onExecute(const std::vector &____inputs, const st } auto core = static_cast(backend())->functions(); auto output = outputs[0]; - auto bytes = CPUBackend::getBytes(backend(), output); + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); auto outputEleSize = static_cast(backend())->getTensorSize(output); auto threadNum = static_cast(backend())->threadNumber(); if (mSingleConvert.type > 0) { diff --git a/source/backend/metal/MetalLoop.mm b/source/backend/metal/MetalLoop.mm index 85010045..205cc08c 100644 --- a/source/backend/metal/MetalLoop.mm +++ b/source/backend/metal/MetalLoop.mm @@ -72,7 +72,7 @@ struct d0 T data[1]; }; -kernel void main0(device d0& uOutput [[buffer(0)]], const device s0& uInputA [[buffer(1)]], const device s1& uInputB [[buffer(2)]], +kernel void loop_matmul(device d0& uOutput [[buffer(0)]], const device s0& uInputA [[buffer(1)]], const device s1& uInputB [[buffer(2)]], #ifdef HAS_BIAS const device s2& uInputC [[buffer(3)]], const device s3& uOOffset [[buffer(4)]], @@ -198,7 +198,7 @@ public: @"HAS_BIAS":@"1", }; } - pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gMatMulUnitTemplate, "loop_matmul", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { @@ -268,6 +268,7 @@ struct constBuffer int4 extent; int4 _step; int4 iter; + int4 totalSize; }; struct s1 @@ -290,7 +291,7 @@ struct s0 T data[1]; }; -kernel void main0(device sourceBuffer& uOutput [[buffer(0)]], const device s0& uInput [[buffer(1)]], const device s1& uSrcOffset [[buffer(2)]], const device s2& uDstOffset [[buffer(3)]], constant constBuffer& uConstant [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +kernel void gather_blit(device sourceBuffer& uOutput [[buffer(0)]], const device s0& uInput [[buffer(1)]], const device s1& uSrcOffset [[buffer(2)]], const device s2& uDstOffset [[buffer(3)]], constant constBuffer& uConstant [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { int3 posTmp = int3(gl_GlobalInvocationID); if (posTmp.x < uConstant._step.w) @@ -322,7 +323,11 @@ kernel void main0(device sourceBuffer& uOutput [[buffer(0)]], const device s0& u } int srcOffset = (((srcBasicOffset + uConstant.stride.w) + (uConstant.stride.z * pos.z)) + (uConstant.stride.y * pos.y)) + (uConstant.stride.x * pos.x); int dstOffset = (((dstBasicOffset + uConstant.extent.w) + (pos.x * uConstant.extent.x)) + (pos.y * uConstant.extent.y)) + (pos.z * uConstant.extent.z); - uOutput.data[dstOffset] = uInput.data[srcOffset]; + if(srcOffset >= 0 && srcOffset < uConstant.totalSize.x) { + if(dstOffset >= 0 && dstOffset < uConstant.totalSize.y) { + uOutput.data[dstOffset] = uInput.data[srcOffset]; + } + } } } )metal"; @@ -333,49 +338,139 @@ struct GatherInfo { int extent[4]; int step[4]; int iter[4]; + int totalSize[4]; +}; +struct InitInfo { + int srcStride[4]; + int dstStride[4]; + int size[4]; + int totalSize[4]; +}; +static const char* gInitRegion = R"metal( +#include +#include +using namespace metal; + +struct constBuffer +{ + int4 srcStride; + int4 dstStride; + int4 size; + int4 totalSize; }; +kernel void set_zero(device T *out [[buffer(0)]], + const device T *in [[buffer(1)]], + constant constBuffer &info [[buffer(2)]], + uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { + int3 gid = int3(gl_GlobalInvocationID); + if (gid.x >= info.size.x || gid.y >= info.size.y || gid.z >= info.size.z) { + return; + } + int dst_offset = (gid.z * info.size.y + gid.y) * info.size.x + gid.x; + if(dst_offset >= 0 && dst_offset < info.totalSize.y) { + out[dst_offset] = (T)0; + } +} + +kernel void set_copy(device T *out [[buffer(0)]], + const device T *in [[buffer(1)]], + constant constBuffer &info [[buffer(2)]], + uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { + int3 gid = int3(gl_GlobalInvocationID); + if (gid.x >= info.size.x || gid.y >= info.size.y || gid.z >= info.size.z) { + return; + } + int src_offset = gid.x * info.srcStride.x + gid.y * info.srcStride.y + gid.z * info.srcStride.z; + int dst_offset = gid.x * info.dstStride.x + gid.y * info.dstStride.y + gid.z * info.dstStride.z; + if(src_offset >= 0 && src_offset < info.totalSize.x) { + if(dst_offset >= 0 && dst_offset < info.totalSize.y) { + out[dst_offset] = in[src_offset]; + } + } +} +)metal"; + class MetalGather : public MetalExecution { private: const LoopParam* mLoop; + bool mNeedInit = false; + std::pair mInitThreads; id mParam; id mPipeline; + id mInitPipeline; + id mInitParam; std::vector mTensors; public: MetalGather(const LoopParam* loop, Backend *bn, const std::vector& inputs, const std::vector& outputs) : MetalExecution(bn) { mLoop = loop; auto mtbn = static_cast(bn); auto context = (__bridge MNNMetalContext *)mtbn->context(); + mParam = [context newDeviceBuffer:sizeof(GatherInfo) access:CPUWriteOnly]; bool useFp16 = mtbn->useFp16InsteadFp32(); mTensors.resize(mLoop->tensorNumber()); auto cmd = mLoop->commands()->GetAs(0); _setTensorStack(mTensors, inputs, outputs, mLoop); auto dstTensor = mTensors[cmd->indexes()->data()[0]]; - + NSString* T = MetalCast::getScalarType(dstTensor->getType(), useFp16); - std::vector keys = { - std::string([T UTF8String]), - "blitregion" - }; - auto pipeline = mtbn->runtime()->findPipeline(keys); - if (nil == pipeline) { - MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init]; - compileOptions.preprocessorMacros = @{ - @"T" : T, + // gather blit command pipeline + { + std::vector keys = { + std::string([T UTF8String]), + "blitregion" }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gBlitRegion, "main0", compileOptions); - mtbn->runtime()->insertPipeline(keys, pipeline); + auto pipeline = mtbn->runtime()->findPipeline(keys); + if (nil == pipeline) { + MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init]; + compileOptions.preprocessorMacros = @{ + @"T" : T, + }; + pipeline = mtbn->makeComputePipelineWithSourceOption(gBlitRegion, "gather_blit", compileOptions); + mtbn->runtime()->insertPipeline(keys, pipeline); + } + if (nil == pipeline) { + MNN_ERROR("Create gather pipeline error\n"); + } + mPipeline = pipeline; } - if (nil == pipeline) { - MNN_ERROR("Create gather pipeline error\n"); + + // scatter need init command pipeline + if(mLoop->initCommand() != nullptr){ + mNeedInit = true; + std::string shader = "set_copy"; + auto cmd = mLoop->initCommand()->GetAs(0); + if (cmd->op() == nullptr){ + shader = "set_zero"; + } else { + mInitParam = [context newDeviceBuffer:sizeof(InitInfo) access:CPUWriteOnly]; + } + std::vector keys = { + std::string([T UTF8String]), + "init_region", + shader + }; + auto pipeline = mtbn->runtime()->findPipeline(keys); + if (nil == pipeline) { + MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init]; + compileOptions.preprocessorMacros = @{ + @"T" : T, + }; + pipeline = mtbn->makeComputePipelineWithSourceOption(gInitRegion, shader.c_str(), compileOptions); + mtbn->runtime()->insertPipeline(keys, pipeline); + } + if (nil == pipeline) { + MNN_ERROR("Create gather init pipeline error\n"); + } + mInitPipeline = pipeline; } - mPipeline = pipeline; } virtual ~MetalGather() = default; virtual ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override { auto cmd = mLoop->commands()->GetAs(0); _setTensorStack(mTensors, inputs, outputs, mLoop); + auto srcStride = cmd->view()->GetAs(1)->stride()->data(); auto dstStride = cmd->view()->GetAs(0)->stride()->data(); auto size = cmd->size()->data(); @@ -401,22 +496,70 @@ public: if (iterIndex[1] >= 0) { param->iter[1] = 1; } + + auto dstTensor = mTensors[cmd->indexes()->data()[0]]; + auto srcTensor = mTensors[cmd->indexes()->data()[1]]; + auto inputSize = srcTensor->usize() / srcTensor->buffer().type.bytes(); + auto outputSize = dstTensor->usize() / dstTensor->buffer().type.bytes(); + param->totalSize[0] = inputSize; + param->totalSize[1] = outputSize; + + if(mNeedInit) { + auto initCmd = mLoop->initCommand()->GetAs(0); + auto data = reinterpret_cast([mInitParam contents]); + + auto srcStride = initCmd->view()->GetAs(1)->stride()->data(); + auto dstStride = initCmd->view()->GetAs(0)->stride()->data(); + auto dataSize = initCmd->size()->data(); + for (int i = 0; i < 3; ++i) { + data->srcStride[i] = srcStride[i]; + data->dstStride[i] = dstStride[i]; + data->size[i] = dataSize[i]; + } + + auto initDstTensor = mTensors[initCmd->indexes()->data()[0]]; + auto initSrcTensor = mTensors[initCmd->indexes()->data()[1]]; + auto initInputSize = initSrcTensor->usize() / initSrcTensor->buffer().type.bytes(); + auto initOutputSize = initDstTensor->usize() / initDstTensor->buffer().type.bytes(); + data->totalSize[0] = initInputSize; + data->totalSize[1] = initOutputSize; + + auto backend = static_cast(this->backend()); + auto context = (__bridge MNNMetalContext *)backend->context(); + mInitThreads = [context computeBestGroupAndLocal:mInitPipeline threads:MTLSizeMake(data->size[0], data->size[1], data->size[2])]; + } return NO_ERROR; } virtual void onEncode(const std::vector& inputs, const std::vector& outputs, id encoder) override { + if(mNeedInit) { + auto initCmd = mLoop->initCommand()->GetAs(0); + int x = initCmd->size()->data()[0]; + int y = initCmd->size()->data()[1]; + int z = initCmd->size()->data()[2]; + + [encoder setComputePipelineState:mInitPipeline]; + auto dstTensor = mTensors[initCmd->indexes()->data()[0]]; + auto srcTensor = mTensors[initCmd->indexes()->data()[1]]; + MetalBackend::setTensor(dstTensor, encoder, 0); + MetalBackend::setTensor(srcTensor, encoder, 1); + [encoder setBuffer:mInitParam offset:0 atIndex:2]; + + [encoder dispatchThreadgroups:mInitThreads.first threadsPerThreadgroup:mInitThreads.second]; + } + auto cmd = mLoop->commands()->GetAs(0); auto size = cmd->size()->data(); auto srcStride = cmd->view()->GetAs(1)->stride()->data(); auto dstStride = cmd->view()->GetAs(0)->stride()->data(); int totalSize = mLoop->loopNumber() * size[0] * size[1] * size[2]; - + [encoder setComputePipelineState:mPipeline]; auto dstTensor = mTensors[cmd->indexes()->data()[0]]; auto srcTensor = mTensors[cmd->indexes()->data()[1]]; MetalBackend::setTensor(dstTensor, encoder, 0); MetalBackend::setTensor(srcTensor, encoder, 1); - + auto iterIndex = cmd->iterIndexes()->data(); if (iterIndex[0] >= 0) { MetalBackend::setTensor(mTensors[iterIndex[0]], encoder, 3); @@ -452,7 +595,7 @@ int computeVec4dot(thread const int4& a, thread const int4& b) return (((a.x * b.x) + (a.y * b.y)) + (a.z * b.z)) + (a.w * b.w); } -kernel void main0(device T1* uOutput [[buffer(0)]], const device T0* uInput0 [[buffer(1)]], const device T0* uInput1 [[buffer(2)]], constant constBuffer& uConstant [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +kernel void loop_binary(device T1* uOutput [[buffer(0)]], const device T0* uInput0 [[buffer(1)]], const device T0* uInput1 [[buffer(2)]], constant constBuffer& uConstant [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { int3 posTmp = int3(gl_GlobalInvocationID); if (posTmp.x < uConstant.size.w) @@ -515,7 +658,7 @@ public: @"T1" : T1, @"CUSTOM" : CUSTOM, }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryBroadcast, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryBroadcast, "loop_binary", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { @@ -577,9 +720,7 @@ public: if (nullptr == loop || loop->commands() == nullptr) { return nullptr; } - if (nullptr != loop->initCommand()) { - return nullptr; - } + // Make Tensor Stack if (1 == loop->commands()->size()) { auto cmd = loop->commands()->GetAs(0); @@ -587,10 +728,10 @@ public: if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) { return new MetalGather(loop, bn, inputs, outputs); } - if (OpType_MatMul == subop->type() && loop->parallel()) { + if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) { return new MetalBatchMatMul(loop, bn); } - if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber()) { + if (OpType_BinaryOp == subop->type() && cmd->fuse() < 0 && 1 == loop->loopNumber() && nullptr == loop->initCommand()) { std::vector tensors(loop->tensorNumber()); _setTensorStack(tensors, inputs, outputs, loop); auto srcTensor = tensors[cmd->indexes()->data()[1]]; diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index 66b4577a..50abef15 100644 --- a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -466,7 +466,7 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp } bool OpenCLBackend::onSelectDynamicAllocator(int index, int maxIndex) { - if (mUseRecordQueue && false == mDevideOpRecord){ + if (mUseRecordQueue && false == mDeviceOpRecord){ return false; } if (maxIndex > 2) { @@ -525,7 +525,7 @@ Execution* OpenCLBackend::onCreate(const std::vector& inputs, const std return NULL; } if (iter == creators->end()) { - mDevideOpRecord = true; + mDeviceOpRecord = true; #ifdef OPENCL_FALLBACK_LOG if (nullptr != op->name()) { MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mMemType, op->name()->c_str()); @@ -565,7 +565,7 @@ Execution* OpenCLBackend::onCreate(const std::vector& inputs, const std } if (!valid) { - mDevideOpRecord = true; + mDeviceOpRecord = true; #ifdef OPENCL_FALLBACK_LOG for (auto t : inputs) { auto tensorShape = OpenCL::tensorShapeFormat(t); @@ -589,7 +589,7 @@ Execution* OpenCLBackend::onCreate(const std::vector& inputs, const std auto exe = iter->second->onCreate(inputs, outputs, op, this); if (NULL == exe) { - mDevideOpRecord = true; + mDeviceOpRecord = true; #ifdef OPENCL_FALLBACK_LOG if (nullptr != op->name()) { MNN_PRINT("The Creator Don't support type %s, memObject:%d, %s\n", MNN::EnumNameOpType(op->type()), mMemType, op->name()->c_str()); @@ -1232,7 +1232,7 @@ int OpenCLBackend::fpBytes() { void OpenCLBackend::clearRecord() const{ #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) - if(mUseRecordQueue && mDevideOpRecord){ + if(mUseRecordQueue && mDeviceOpRecord){ for(int i = 0; i < mRecordings.size(); ++i){ std::vector update_kernel_args; std::vector update_global_size; @@ -1263,7 +1263,7 @@ void OpenCLBackend::clearRecord() const{ void OpenCLBackend::enqeueRecord() const{ #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) - if(mUseRecordQueue && !mDevideOpRecord){ + if(mUseRecordQueue && !mDeviceOpRecord){ for(int i = 0; i < mRecordings.size(); ++i){ std::vector update_kernel_args; std::vector update_global_size; @@ -1290,7 +1290,7 @@ void OpenCLBackend::enqeueRecord() const{ void OpenCLBackend::releaseRecord(){ #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) - if(mUseRecordQueue && !mDevideOpRecord){ + if(mUseRecordQueue && !mDeviceOpRecord){ for(int i = 0; i < mRecordings.size(); ++i){ cl_int res = clReleaseRecordingQCOM(mRecordings[i].record); MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM"); @@ -1309,7 +1309,7 @@ void OpenCLBackend::startRecord(cl_recording_qcom &recording){ MNN_PRINT("start startRecord !\n"); #endif cl_int res = CL_SUCCESS; - if(mDevideOpRecord){ + if(mDeviceOpRecord){ if(recording != NULL){ clReleaseRecordingQCOM(recording); } @@ -1330,7 +1330,7 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){ #ifdef LOG_VERBOSE MNN_PRINT("start endRecord !\n"); #endif - if(mDevideOpRecord){ + if(mDeviceOpRecord){ cl_int res = CL_SUCCESS; res = clEndRecordingQCOM(recording); MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); @@ -1349,7 +1349,7 @@ void OpenCLBackend::endRecord(cl_recording_qcom &recording, bool flag){ } void OpenCLBackend::addRecord(cl_recording_qcom &record, std::vectorupdateInfo){ - if(mDevideOpRecord){ + if(mDeviceOpRecord){ RecordInfo info; info.record = record; for(int i = 0; i < updateInfo.size(); ++i) { @@ -1369,7 +1369,7 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr &kernelW, c MNN_PRINT("start record2dKernel !\n"); #endif cl_int res = CL_SUCCESS; - if(!mDevideOpRecord){ + if(!mDeviceOpRecord){ RecordInfo info; int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums; if(updateInfo != nullptr){ @@ -1439,7 +1439,7 @@ void OpenCLBackend::recordKernel3d(const std::shared_ptr &kernelW, c for (size_t i = 0; i < 3; ++i) { internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i])); } - if(!mDevideOpRecord){ + if(!mDeviceOpRecord){ RecordInfo info; int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums; if(updateInfo != nullptr){ @@ -1547,12 +1547,12 @@ void OpenCLBackend::setGpuMode(const int cl_mode_num) { mUseRecordQueue = ((cl_mode_num & MNN_GPU_RECORD_OP) || (cl_mode_num & MNN_GPU_RECORD_BATCH)) && mOpenCLRuntime->isSupportRecordQueue() && (mUseRecordableQueueSize > 0); isSet = (cl_mode_num & MNN_GPU_RECORD_OP); if(isSet) { - mDevideOpRecord = true; + mDeviceOpRecord = true; totalSet++; } isSet = (cl_mode_num & MNN_GPU_RECORD_BATCH); if(isSet) { - mDevideOpRecord = false; + mDeviceOpRecord = false; totalSet++; } if(totalSet > 1) { diff --git a/source/backend/opencl/core/OpenCLBackend.hpp b/source/backend/opencl/core/OpenCLBackend.hpp index 52b1a891..cc31ecf5 100644 --- a/source/backend/opencl/core/OpenCLBackend.hpp +++ b/source/backend/opencl/core/OpenCLBackend.hpp @@ -137,7 +137,7 @@ public: return mUseRecordQueue; } bool isDevideOpRecord(){ - return mDevideOpRecord; + return mDeviceOpRecord; } CLTuneLevel getCLTuneLevel() { return mTuneLevel; @@ -185,7 +185,8 @@ private: bool mIsCreateError{false}; mutable std::vector mRecordings; bool mUseRecordQueue = false; - bool mDevideOpRecord = false; + bool mDeviceOpRecord = false; + friend class setRecordClose; uint32_t mRecordNums = 0; uint32_t mUseRecordableQueueSize; private: @@ -200,6 +201,25 @@ private: }; +class setRecordClose{ +public: + setRecordClose(OpenCLBackend *bn){ + backend = bn; + if(backend->mUseRecordQueue){ + backend->mUseRecordQueue = false; + needRecover = true; + } + } + ~setRecordClose(){ + if(needRecover){ + backend->mUseRecordQueue = true; + } + } +private: + bool needRecover = false; + OpenCLBackend* backend; +}; + template class OpenCLCreatorRegister { public: diff --git a/source/backend/opencl/core/OpenCLRunningUtils.cpp b/source/backend/opencl/core/OpenCLRunningUtils.cpp index 3c98407f..bdbcb6bf 100644 --- a/source/backend/opencl/core/OpenCLRunningUtils.cpp +++ b/source/backend/opencl/core/OpenCLRunningUtils.cpp @@ -633,20 +633,21 @@ bool localWSTune(const std::mappreParamsMap(); - if (preParamInfo.find(preParamName) != preParamInfo.end()) { - *preParamData = preParamInfo[preParamName]; +bool getTunedInfo(const std::string kernelName, const std::vector &gws, std::pair, uint32_t> &tuneInfo, OpenCLRuntime *runtime){ + auto& tunedLws = runtime->tunedLwsMap(); + auto& tuneLws = runtime->getTuneLwsMap(); + std::pair> info = std::make_pair(kernelName, gws); + if (tunedLws.find(info) != tunedLws.end()) { + tuneInfo = tunedLws[info]; return true; } - return false; + return localWSTune(tuneLws, gws, kernelName, tuneInfo); } -void setPreParamInfo(const std::string preParamName, uint32_t preParamData, OpenCLRuntime *runtime){ - auto& preParamInfo = runtime->preParamsMap(); - if (preParamInfo.find(preParamName) == preParamInfo.end()) { - preParamInfo.insert(std::make_pair(preParamName, preParamData)); - } +void setTunedInfo(const std::string kernelName, const std::vector &gws, std::pair, uint32_t> &tuneInfo, OpenCLRuntime *runtime){ + auto& tunedLws = runtime->tunedLwsMap(); + std::pair> info = std::make_pair(kernelName, gws); + tunedLws.insert(std::make_pair(info, std::make_pair(tuneInfo.first, tuneInfo.second))); } } // namespace OpenCL diff --git a/source/backend/opencl/core/OpenCLRunningUtils.hpp b/source/backend/opencl/core/OpenCLRunningUtils.hpp index 780f1e2b..3b34ef72 100644 --- a/source/backend/opencl/core/OpenCLRunningUtils.hpp +++ b/source/backend/opencl/core/OpenCLRunningUtils.hpp @@ -132,9 +132,9 @@ uint32_t get2DUseLocalMemTime(const std::vector &gws, const std::vecto std::pair, uint32_t> localWS2DDefault(const std::vector &gws, const uint32_t maxWorkGroupSize, OpenCLRuntime *runtime, const std::string &kernelName, const std::shared_ptr &mKernel, int tuneLevel); -bool getPreParamInfo(const std::string preParamName, uint32_t *preParamData, OpenCLRuntime *runtime); +bool getTunedInfo(const std::string kernelName, const std::vector &gws, std::pair, uint32_t> &tuneInfo, OpenCLRuntime *runtime); -void setPreParamInfo(const std::string preParamName, uint32_t preParamData, OpenCLRuntime *runtime); +void setTunedInfo(const std::string kernelName, const std::vector &gws, std::pair, uint32_t> &tuneInfo, OpenCLRuntime *runtime); void copyBufferToImage(OpenCLRuntime *runtime, const cl::Buffer &buffer, const cl::Image &image, int w, int h, int precision); diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp index f672d056..6506fec5 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp @@ -348,10 +348,6 @@ unsigned int OpenCLRuntime::getQueueNum() { return mQueueCount; } -std::map& OpenCLRuntime::preParamsMap(){ - return mPreParams; -} - std::map, std::vector>& OpenCLRuntime::tunedGemmParamsMap() { return mTunedGemmParams; } @@ -804,14 +800,6 @@ std::pair OpenCLRuntime::makeCache(void* tuneInfo) { backend->gemm.emplace_back(std::move(tuning)); } - // Get All PreParam cache - for(auto& iter : mPreParams){ - std::unique_ptr info(new PreParamInfoT); - info->preParamName = iter.first; - info->preParamData = iter.second; - backend->preParam.emplace_back(std::move(info)); - } - cache->backends.emplace_back(std::move(backend)); flatbuffers::FlatBufferBuilder builder; @@ -921,23 +909,18 @@ bool OpenCLRuntime::setCache(std::pair cache) { } } } - - //Load PreParam Info - if(nullptr != backendinfo->preParam()){ - auto preParamInfo = backendinfo->preParam(); - for(int i = 0; i < preParamInfo->size(); ++i){ - auto info = preParamInfo->GetAs(i); - if (nullptr == info->preParamName()) { - MNN_ERROR("Error preParam info\n"); - return false; - } - mPreParams.insert(std::make_pair(info->preParamName()->str(), info->preParamData())); - } - } } return true; } +unsigned int OpenCLRuntime::getEventTime(cl::Event& event){ + cl_int res = event.wait(); + MNN_CHECK_CL_SUCCESS(res, "clEvent"); + auto StartNanos = event.getProfilingInfo(); + auto StopNanos = event.getProfilingInfo(); + return (unsigned int)((StopNanos - StartNanos) / 1000.0); +} + void OpenCLRuntime::printEventTime(){ #ifdef ENABLE_OPENCL_TIME_PROFILER if(mEvents.empty()){ diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp index 44602856..37672131 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp @@ -128,6 +128,7 @@ public: void pushEvent(std::pair data) { return mEvents.push_back(data); } + unsigned int getEventTime(cl::Event& event); void printEventTime(); void clearEvent(){ mKernelTime = 0; @@ -143,8 +144,6 @@ public: unsigned int mKernelTime = 0; - std::map& preParamsMap(); - std::map, std::vector>& tunedGemmParamsMap(); std::map>, std::pair, uint32_t>>& tunedLwsMap(); @@ -232,7 +231,6 @@ private: double mStartNanos; double mStopNanos; - std::map mPreParams; std::map, std::vector> mTunedGemmParams; std::map>, std::pair, uint32_t>> mTunedLws; std::map, std::pair, uint32_t>>>> mTuneLws; diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp index 100a813a..80e7e9b4 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp @@ -17,7 +17,7 @@ KVCacheCLManager::KVCacheCLManager(Backend *backend, bool kv_cahce) : mKVCache(k mOpenCLBackend = static_cast(backend); } -void KVCacheCLManager::allocKVCache(const KVMeta* meta, bool isDecodeResize) { +void KVCacheCLManager::allocKVCache(const KVMeta* meta) { if (!mKVCache) { return; } @@ -25,10 +25,10 @@ void KVCacheCLManager::allocKVCache(const KVMeta* meta, bool isDecodeResize) { if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ mByte = 2; } - reallocKVCache(meta, isDecodeResize); + reallocKVCache(meta, false); } -bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isDecodeResize) { +bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isExecute) { if (!mKVCache) { return false; } @@ -91,14 +91,14 @@ bool KVCacheCLManager::reallocKVCache(const KVMeta* meta, bool isDecodeResize) { mPastKey.reset(newKey); mPastValue.reset(newValue); // resize phase don't update mPastLength value, excute phase will update it - if(false == isDecodeResize){ + if(isExecute){ mPastLength = start; } } // Remove // resize phase don't remove kvcache, excute phase will do it - if(false == isDecodeResize){ + if(isExecute){ if (0 == meta->n_reserve) { mPastLength = start; return true; @@ -161,7 +161,7 @@ void AttentionBufExecution::handleKVCache(const std::vector &inputs, c int mask_seqlen = mask_shape[2]; int maskKvlen = mask_shape[3]; mKVCacheCLManager->setArgs(numHead, kvNumHead, headDim); - mKVCacheCLManager->allocKVCache(mMeta, mIsDecode); + mKVCacheCLManager->allocKVCache(mMeta); mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4); mDecodeTmpMaxlen = mKeyValueMaxlen; mPastKvSeqlen = mKVCacheCLManager->pastKvLength(); @@ -177,6 +177,9 @@ ErrorCode AttentionBufExecution::init() { mRgQUpdateInfo.update_kernel_args.clear(); mRgQUpdateInfo.update_global_size.clear(); mRgQUpdateInfo.update_local_size.clear(); + mRgMUpdateInfo.update_kernel_args.clear(); + mRgMUpdateInfo.update_global_size.clear(); + mRgMUpdateInfo.update_local_size.clear(); mRgUpdateInfo.update_kernel_args.clear(); mRgUpdateInfo.update_global_size.clear(); mRgUpdateInfo.update_local_size.clear(); @@ -222,9 +225,36 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, mKVCacheCLManager->addKvLength(seqlen); // prefill if(mIsDecode == false){ + int maskKvlen = mKvSeqlen; + int maskQlen = seqlen; + if(mHasMask) { + auto mask = inputs[3]; + auto mask_shape = mask->shape(); + maskQlen = mask_shape[2]; + maskKvlen = mask_shape[3]; + } // key value static memory has been changed, need reset args if(mKeyValueMaxlen != ROUND_UP(mKVCacheCLManager->maxLength(), 4)){ mKeyValueMaxlen = ROUND_UP(mKVCacheCLManager->maxLength(), 4); + } + if(false == mLongPrefill){ + mGlobalWorkSizeQk0 = UP_DIV(mKvSeqlen, 4); + mQkPrefillGlobal_size[1] = ROUND_UP(mGlobalWorkSizeQk0, std::max((uint32_t)1, mLocalWorkSizeQk[1])); + mGlobalWorkSizeQk[1] = mQkPrefillGlobal_size[1]; + mTempQK.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); + mTempSoftMax.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); + if(mIsAddMask) { + mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } else { + mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } + mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); + } #ifndef ENABLE_OPENCL_TIME_PROFILER if(mOpenCLBackend->isUseRecordQueue()){ if(mLongPrefill){ @@ -233,10 +263,22 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))(); mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))(); }else{ + mRgQUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQ.get())(); mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); - mQkUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); + mRgMUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempMask.get())(); + mQkUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempQ.get())(); + mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mKVCacheCLManager->key()))(); + if(mHasMask){ + mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempMask.get())(); + mQkUpdateInfo.update_kernel_args[4].arg_value = &openCLDeferBuffer(mTempQK.get())(); + }else{ + mQkUpdateInfo.update_kernel_args[3].arg_value = &openCLDeferBuffer(mTempQK.get())(); + } + mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQK.get())(); + mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempSoftMax.get())(); mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))(); - mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))(); + mQkvUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempSoftMax.get())(); + mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))(); } } else { #endif @@ -256,40 +298,58 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector &inputs, ret |= mKernel_rearrange->get().setArg(6, mKeyValueMaxlen); MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k"); } - { - // matmul qk - cl_int ret = CL_SUCCESS; - ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key()); - ret |= mKernel_qk->get().setArg(10, mKvSeqlen); - ret |= mKernel_qk->get().setArg(11, mKeyValueMaxlen); - MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qk_decode"); - } - { - // softmax - cl_int ret = CL_SUCCESS; - ret |= mKernel_qk->get().setArg(7, mKvSeqlen); - MNN_CHECK_CL_SUCCESS(ret, "reSetArg softmax"); - } - { - cl_int ret = CL_SUCCESS; - ret |= mKernel_rearrangeV->get().setArg(4, *mKVCacheCLManager->value()); - ret |= mKernel_rearrangeV->get().setArg(5, mPastKvSeqlen); - ret |= mKernel_rearrangeV->get().setArg(6, mKeyValueMaxlen); - MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_v"); - } - // qk * value - { - cl_int ret = CL_SUCCESS; - ret |= mKernel_qkv->get().setArg(4, *mKVCacheCLManager->value()); - ret |= mKernel_qkv->get().setArg(7, mKvSeqlen); - ret |= mKernel_qkv->get().setArg(8, mKeyValueMaxlen); - MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qkv_decode"); - } + if(mHasMask){ + // rearrange mask + cl_int ret = CL_SUCCESS; + ret |= mKernel_rearrangeMask->get().setArg(4, openCLDeferBuffer(mTempMask.get())); + MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_mask_shortprefill"); + } + { + // matmul qk + mGlobalWorkSizeQk = {static_cast(UP_DIV(seqlen, 4)), static_cast(UP_DIV(mKvSeqlen, 4)), static_cast(numHead*batch)}; + cl_int ret = CL_SUCCESS; + ret |= mKernel_qk->get().setArg(1, mGlobalWorkSizeQk0); + ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key()); + if(mHasMask) { + ret |= mKernel_qk->get().setArg(5, openCLDeferBuffer(mTempMask.get())); + } + ret |= mKernel_qk->get().setArg(6, openCLDeferBuffer(mTempQK.get())); + ret |= mKernel_qk->get().setArg(10, mKvSeqlen); + ret |= mKernel_qk->get().setArg(11, mKeyValueMaxlen); + MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qk_decode"); + mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); + mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1])); + mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2])); + } + { + // softmax + cl_int ret = CL_SUCCESS; + ret |= mKernel_softmax->get().setArg(3, openCLDeferBuffer(mTempQK.get())); + ret |= mKernel_softmax->get().setArg(4, openCLDeferBuffer(mTempSoftMax.get())); + ret |= mKernel_softmax->get().setArg(7, mKvSeqlen); + MNN_CHECK_CL_SUCCESS(ret, "reSetArg softmax"); + } + { + // rearrange value + cl_int ret = CL_SUCCESS; + ret |= mKernel_rearrangeV->get().setArg(4, *mKVCacheCLManager->value()); + ret |= mKernel_rearrangeV->get().setArg(5, mPastKvSeqlen); + ret |= mKernel_rearrangeV->get().setArg(6, mKeyValueMaxlen); + MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_v"); + } + // qk * value + { + cl_int ret = CL_SUCCESS; + ret |= mKernel_qkv->get().setArg(3, openCLDeferBuffer(mTempSoftMax.get())); + ret |= mKernel_qkv->get().setArg(4, *mKVCacheCLManager->value()); + ret |= mKernel_qkv->get().setArg(7, mKvSeqlen); + ret |= mKernel_qkv->get().setArg(8, mKeyValueMaxlen); + MNN_CHECK_CL_SUCCESS(ret, "reSetArg matmul_qkv_decode"); } - #ifndef ENABLE_OPENCL_TIME_PROFILER } - #endif + #ifndef ENABLE_OPENCL_TIME_PROFILER } + #endif return NO_ERROR; } @@ -882,20 +942,30 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu float scale = 1.0 / sqrt(headDim); int maskKvlen = mKvSeqlen; + int maskQlen = seqlen; if(mHasMask) { auto mask = inputs[3]; auto mask_shape = mask->shape(); + maskQlen = mask_shape[2]; maskKvlen = mask_shape[3]; + if(mIsAddMask) { + mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } else { + mTempMask.reset(Tensor::createDevice({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch})); + } } mTempQ.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch})); mTempQK.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); mTempSoftMax.reset(Tensor::createDevice({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch})); - mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); - mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); - mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); + if(mHasMask){ + mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + } cl::Buffer keyBuffer, valueBuffer; if(mNeedKvCache) { @@ -911,9 +981,12 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu keyBuffer = openCLBuffer(mTempK.get()); valueBuffer = openCLBuffer(mTempV.get()); } - mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION); + mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION); + if(mHasMask){ + mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION); + } { // rearrange query @@ -932,7 +1005,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_rearrangeQ->get().setArg(index++, mGlobalWorkSizeRearrgQ[1]); ret |= mKernel_rearrangeQ->get().setArg(index++, mGlobalWorkSizeRearrgQ[2]); ret |= mKernel_rearrangeQ->get().setArg(index++, openCLBuffer(query)); - ret |= mKernel_rearrangeQ->get().setArg(index++, openCLBuffer(mTempQ.get())); + ret |= mKernel_rearrangeQ->get().setArg(index++, openCLDeferBuffer(mTempQ.get())); ret |= mKernel_rearrangeQ->get().setArg(index++, seqlen); ret |= mKernel_rearrangeQ->get().setArg(index++, headDim); ret |= mKernel_rearrangeQ->get().setArg(index++, numHead); @@ -942,8 +1015,9 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mGlobalWorkSizeRearrgQ[0] = ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0])); mGlobalWorkSizeRearrgQ[1] = ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1])); mGlobalWorkSizeRearrgQ[2] = ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2])); + mRgQUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()}); mOpRecordUpdateInfo.emplace_back(&mRgQUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, &mRgQUpdateInfo); } { // rearrange key @@ -984,6 +1058,34 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo); mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo); } + if (mHasMask){ + std::set buildOption; + if(mIsAddMask){ + buildOption.emplace("-DADD_MASK"); + } else if(mHasMask) { + buildOption.emplace("-DSET_MASK"); + } + mKernel_rearrangeMask = runtime->buildKernel("attention_buf", "rearrange_mask_shortprefill", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]); + mGlobalWorkSizeRearrgM = {static_cast(UP_DIV(maskQlen, 4)), static_cast(UP_DIV(maskKvlen, 4)), static_cast(batch)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrangeMask)); + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[0]); + ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[1]); + ret |= mKernel_rearrangeMask->get().setArg(index++, mGlobalWorkSizeRearrgM[2]); + ret |= mKernel_rearrangeMask->get().setArg(index++, openCLBuffer(inputs[3])); + ret |= mKernel_rearrangeMask->get().setArg(index++, openCLDeferBuffer(mTempMask.get())); + ret |= mKernel_rearrangeMask->get().setArg(index++, maskQlen); + ret |= mKernel_rearrangeMask->get().setArg(index++, maskKvlen); + MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask_shortprefill"); + mLocalWorkSizeRearrgM = localWS3DDefault(mGlobalWorkSizeRearrgM, maxWorkGroupSize, runtime, "rearrange_mask_shortprefill", mKernel_rearrangeMask, mOpenCLBackend->getCLTuneLevel()).first; + mGlobalWorkSizeRearrgM[0] = ROUND_UP(mGlobalWorkSizeRearrgM[0], std::max((uint32_t)1, mLocalWorkSizeRearrgM[0])); + mGlobalWorkSizeRearrgM[1] = ROUND_UP(mGlobalWorkSizeRearrgM[1], std::max((uint32_t)1, mLocalWorkSizeRearrgM[1])); + mGlobalWorkSizeRearrgM[2] = ROUND_UP(mGlobalWorkSizeRearrgM[2], std::max((uint32_t)1, mLocalWorkSizeRearrgM[2])); + mRgMUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()}); + mOpRecordUpdateInfo.emplace_back(&mRgMUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, &mRgMUpdateInfo); + } { // matmul qk std::set buildOption; @@ -1002,12 +1104,12 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]); ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]); ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[2]); - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQ.get())); + ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQ.get())); ret |= mKernel_qk->get().setArg(index++, keyBuffer); if(mHasMask) { - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(inputs[3])); + ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempMask.get())); } - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQK.get())); + ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get())); ret |= mKernel_qk->get().setArg(index++, scale); ret |= mKernel_qk->get().setArg(index++, seqlen); ret |= mKernel_qk->get().setArg(index++, maskKvlen); @@ -1021,16 +1123,25 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1])); mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2])); + mQkUpdateInfo.update_kernel_args.push_back({0, 1, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0}); + mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQ.get())()}); if(mNeedKvCache) { mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); } if(mHasMask){ + mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempMask.get())()}); + mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKvSeqlen), &mKvSeqlen}); mQkUpdateInfo.update_kernel_args.push_back({0, 11, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); }else{ + mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); mQkUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(mKvSeqlen), &mKvSeqlen}); mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen}); } + mQkPrefillGlobal_size[0] = mGlobalWorkSizeQk[0]; + mQkPrefillGlobal_size[1] = mGlobalWorkSizeQk[1]; + mQkPrefillGlobal_size[2] = mGlobalWorkSizeQk[2]; + mQkUpdateInfo.update_global_size.push_back({0, mQkPrefillGlobal_size}); mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo); mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo); } @@ -1051,8 +1162,8 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]); ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]); ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]); - ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempQK.get())); - ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempQK.get())); + ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get())); ret |= mKernel_softmax->get().setArg(index++, inside); ret |= mKernel_softmax->get().setArg(index++, outside); ret |= mKernel_softmax->get().setArg(index++, mKvSeqlen); @@ -1065,9 +1176,11 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); + mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()}); + mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen}); mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax); + mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo); } { // rearrange value @@ -1120,7 +1233,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]); ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]); ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]); - ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_qkv->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get())); ret |= mKernel_qkv->get().setArg(index++, valueBuffer); ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0])); ret |= mKernel_qkv->get().setArg(index++, seqlen); @@ -1135,6 +1248,7 @@ ErrorCode AttentionBufExecution::prefillResize(const std::vector &inpu mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0])); mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1])); mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2])); + mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()}); if(mNeedKvCache) { mQkvUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); } @@ -1164,15 +1278,7 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input int group_size = numHead / kvNumHead; float scale = 1.0 / sqrt(headDim); - int mask_seqlen = seqlen; - int mask_kvlen = seqlen; - if(mHasMask) { - auto mask = inputs[3]; - auto mask_shape = mask->shape(); - mask_seqlen = mask_shape[2]; - mask_kvlen = mask_shape[3]; - } cl::Buffer keyBuffer, valueBuffer; if(mNeedKvCache) { keyBuffer = *mKVCacheCLManager->key(); @@ -1440,13 +1546,17 @@ ErrorCode AttentionBufExecution::decodeResize(const std::vector &input ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, const std::vector &outputs) { mOpenCLBackend->startRecord(mRecording); auto shape = inputs[0]->shape(); + int batch = shape[0]; int seqlen = shape[1]; + int numHead = shape[2]; + int headDim = shape[3]; + int kvNumHead = inputs[1]->shape()[2]; if(mNeedKvCache) { // if has kv_cache, default has mask MNN_ASSERT(inputs.size() > 3); } mHasMask = inputs.size() > 3; - mIsDecode = seqlen == 1; + mIsDecode = seqlen == 1 && mMeta->add == 1; // reset updateArgs variable and kernel vector init(); @@ -1457,22 +1567,99 @@ ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, c if(mIsDecode) { return decodeResize(inputs, outputs); } else { - if(seqlen > 512 && mPastKvSeqlen == 0){ - mLongPrefill = true; - return longPrefillResize(inputs, outputs); + if(mPastKvSeqlen == 0){ + std::pair, uint32_t> tuneInfo; + std::string info = "attention_" + std::to_string(batch) + "_" + std::to_string(numHead) + "_" + std::to_string(headDim) + "_" + std::to_string(kvNumHead); + if(seqlen > 16){ + if(getTunedInfo(info, {static_cast(seqlen)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime())){ + mLongPrefill = tuneInfo.first[0]; + } else{ + if (mOpenCLBackend->getCLTuneLevel() == Heavy || mOpenCLBackend->getCLTuneLevel() == Wide){ + setRecordClose closeRecord(mOpenCLBackend); + // tunning choose use witch preill + prefillResize(inputs, outputs); + auto shortPrefillTime = getExecuteTime(); + init(); + mLongPrefill = true; + longPrefillResize(inputs, outputs); + auto longPrefillTime = getExecuteTime(); + mLongPrefill = false; + if(longPrefillTime < shortPrefillTime){ + mLongPrefill = true; + } + std::pair, uint32_t> tuneInfoTmp = std::make_pair, uint32_t>({mLongPrefill}, 0); + setTunedInfo(info, {static_cast(seqlen)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime()); + init(); + }else{ + if(seqlen > 512){ + mLongPrefill = true; + } + } + } + } + } + if(mLongPrefill){ + longPrefillResize(inputs, outputs); }else{ - return prefillResize(inputs, outputs); + prefillResize(inputs, outputs); } } return NO_ERROR; } +int AttentionBufExecution::getExecuteTime(){ + int executeTime = 0; + auto runtime = mOpenCLBackend->getOpenCLRuntime(); + if(mLongPrefill) { + int seq_idx = 0; + cl::Event event0, event1, event2, event3, event4, event5, event6; + run3DKernelDefault(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event0); + executeTime += runtime->getEventTime(event0); + if(mHasMask) { + run3DKernelDefault(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event1); + executeTime += runtime->getEventTime(event1); + } + for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { + run3DKernelDefault(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event2); + executeTime += runtime->getEventTime(event2); + run3DKernelDefault(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event3); + executeTime += runtime->getEventTime(event3); + run3DKernelDefault(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event4); + executeTime += runtime->getEventTime(event4); + run3DKernelDefault(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event5); + executeTime += runtime->getEventTime(event5); + } + seq_idx = 0; + run3DKernelDefault(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx], mOpenCLBackend->getOpenCLRuntime(), &event6); + executeTime += runtime->getEventTime(event6); + } else{ + cl::Event event0, event1, event2, event3, event4, event5, event6; + run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0); + executeTime += runtime->getEventTime(event0); + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1); + executeTime += runtime->getEventTime(event1); + if(mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2); + executeTime += runtime->getEventTime(event2); + } + run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3); + executeTime += runtime->getEventTime(event3); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4); + executeTime += runtime->getEventTime(event4); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5); + executeTime += runtime->getEventTime(event5); + run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6); + executeTime += runtime->getEventTime(event6); + } + return executeTime; +} + ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { #ifdef LOG_VERBOSE MNN_PRINT("start AttentionBufExecution onExecute !\n"); #endif - if(mNeedKvCache && mIsDecode){ + if(mNeedKvCache){ mKVCacheCLManager->reallocKVCache(mMeta); } UpdateArgs(inputs, outputs); @@ -1513,19 +1700,23 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, runKernel2D(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event4); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event4}); }else{ - cl::Event event0, event1, event2, event3, event4, event5; + cl::Event event0, event1, event2, event3, event4, event5, event6; run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime(), &event0); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_q", event0}); run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event1); mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_k", event1}); - run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event2); - mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event2}); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event3); - mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event3}); - run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event4); - mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event4}); - run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event5); - mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event5}); + if(mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime(), &event2); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_mask_shortprefill", event2}); + } + run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime(), &event3); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event3}); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime(), &event4); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event4}); + run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime(), &event5); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_v", event5}); + run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime(), &event6); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event6}); } } #else @@ -1562,6 +1753,9 @@ ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, }else{ run3DKernelDefault(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ, mOpenCLBackend->getOpenCLRuntime()); run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime()); + if(mHasMask) { + run3DKernelDefault(mKernel_rearrangeMask, mGlobalWorkSizeRearrgM, mLocalWorkSizeRearrgM, mOpenCLBackend->getOpenCLRuntime()); + } run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime()); run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime()); run3DKernelDefault(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, mOpenCLBackend->getOpenCLRuntime()); diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp index 79621836..b89037f4 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp @@ -22,8 +22,8 @@ public: KVCacheCLManager(Backend *backend, bool kv_cache); ~KVCacheCLManager() = default; - void allocKVCache(const KVMeta* meta, bool isDecodeResize = false); - bool reallocKVCache(const KVMeta* meta, bool isDecodeResize = false); + void allocKVCache(const KVMeta* meta); + bool reallocKVCache(const KVMeta* meta, bool isExecute = true); void setArgs(int numHead, int kvNumHead, int headDim){ mNumHead = numHead; mKvNumHead = kvNumHead; @@ -67,6 +67,7 @@ public: ErrorCode UpdateArgs(const std::vector &inputs, const std::vector &outputs); ErrorCode init(); + int getExecuteTime(); virtual ~AttentionBufExecution() = default; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; @@ -88,12 +89,14 @@ private: OpenCLBackend *mOpenCLBackend; RecordUpdateInfo mRgUpdateInfo; RecordUpdateInfo mRgQUpdateInfo; + RecordUpdateInfo mRgMUpdateInfo; RecordUpdateInfo mQkUpdateInfo; RecordUpdateInfo mSoftMaxUpdateInfo; RecordUpdateInfo mRgVUpdateInfo; RecordUpdateInfo mQkvUpdateInfo; int mGlobalWorkSizeQk0 = 0; size_t mQkGlobal_size[2]; + size_t mQkPrefillGlobal_size[3]; std::vector mOpRecordUpdateInfo; std::shared_ptr mKVCacheCLManager; std::shared_ptr mTempQK, mTempSoftMax; @@ -131,6 +134,7 @@ private: private: std::shared_ptr mKernel_rearrangeQ; std::shared_ptr mKernel_rearrangeV; + std::shared_ptr mKernel_rearrangeMask; std::shared_ptr mKernel_rearrange; std::shared_ptr mKernel_qk; std::shared_ptr mKernel_softmax; @@ -148,6 +152,8 @@ private: std::vector mLocalWorkSizeRearrgV; std::vector mGlobalWorkSizeRearrg; std::vector mLocalWorkSizeRearrg; + std::vector mGlobalWorkSizeRearrgM; + std::vector mLocalWorkSizeRearrgM; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp index c7f6739a..a21877af 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp @@ -614,15 +614,12 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu const int blockDim = mResource->mInputChannel / mResource->mBlockSize; bool useLocalMem = inputChannels >= 32; std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); - std::string kernelName = "gemv_conv_c8"; std::set buildOption = mResource->mBuildOptions; int inputChannelLeaves = 0; if(mResource->mNumQuantBit == 4){ inputChannelLeaves = useLocalMem ? (inputChannels % 4) : (blockDim % 4); - kernelName += "_int4_buf"; } else { inputChannelLeaves = useLocalMem ? (inputChannels % 2) : (blockDim % 2); - kernelName += "_int8_buf"; } if(outChannel % 8 != 0){ buildOption.emplace("-DOUTPUT_CHANNEL_LEAVES"); @@ -638,7 +635,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu for (int ksize = 8; ksize <= 256; ksize*=2) { auto option = buildOption; option.emplace("-DWGS=" + std::to_string(ksize)); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName, option, mOpenCLBackend->getPrecision()); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); std::vector gws = {static_cast(ksize), static_cast(UP_DIV(outChannel, 8))}; std::vector lws = {static_cast(ksize), 1}; @@ -646,6 +643,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu cl_int ret = CL_SUCCESS; ret |= kernel->get().setArg(idx++, static_cast(gws[0])); ret |= kernel->get().setArg(idx++, static_cast(gws[1])); + ret |= kernel->get().setArg(idx++, static_cast(gws[1])); ret |= kernel->get().setArg(idx++, openCLBuffer(input)); if(mResource->mUseImage){ ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get()); @@ -657,13 +655,15 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu ret |= kernel->get().setArg(idx++, openCLBuffer(output)); ret |= kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); ret |= kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); + ret |= kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= kernel->get().setArg(idx++, inputChannels); ret |= kernel->get().setArg(idx++, static_cast(blockNum)); ret |= kernel->get().setArg(idx++, static_cast(blockDim)); ret |= kernel->get().setArg(idx++, static_cast(mResource->mCoef)); - MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select"); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select"); std::pair, int> retTune; - int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, kernel); + int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, kernel); if(min_time > cost_time) { local_size = ksize; min_time = cost_time; @@ -673,12 +673,13 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu buildOption.emplace("-DWGS=" + std::to_string(local_size)); mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 8))}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName, buildOption, mOpenCLBackend->getPrecision()); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); if(mResource->mUseImage){ ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); @@ -690,15 +691,17 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockNum)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); ret |= unit.kernel->get().setArg(idx++, static_cast(mResource->mCoef)); - MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c4_0_buf"); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf"); if(useLocalMem){ mLocalWorkSize = {static_cast(local_size), 1}; }else{ - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; } mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; @@ -747,8 +750,142 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, option, mOpenCLBackend->getPrecision()); } std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); - - + if(global_y <= 16) { + mUnits.resize(3); + int outputChannelAlign8 = ROUND_UP(outChannel, 8); + mConvGemmInpTensor.reset(Tensor::createDevice({inputChannelAlign * ROUND_UP(global_y, 4)})); + mConvGemmOutTensor.reset(Tensor::createDevice({outputChannelAlign8 * ROUND_UP(global_y, 4)})); + mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + + { + //c4nhw4 -> nhwc + auto &unit = mUnits[0]; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_c4nhw4_to_nhwc", buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + + mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), static_cast(UP_DIV(inputChannels, 4))}; + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); + ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelAlign)); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_c4nhw4_to_nhwc"); + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_c4nhw4_to_nhwc", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + } + { + const int inputChannelBlocks = UP_DIV(inputChannels, 4); + const int outputChannelBlocks = UP_DIV(outChannel, 4); + auto &unit = mUnits[1]; + std::set buildOption = mResource->mBuildOptions; + if(mResource->mUseImage){ + buildOption.emplace("-DUSE_IMAGE"); + } + buildOption.emplace("-DCOMPUTE_BATCH"); + + int local_size = 64; + if(mOpenCLBackend->getCLTuneLevel() != None && mOpenCLBackend->getCLTuneLevel() != Fast){ + int min_time = INT_MAX; + for (int ksize = 16; ksize <= 256; ksize*=2) { + auto option = buildOption; + option.emplace("-DWGS=" + std::to_string(ksize)); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", option, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); + std::vector gws = {static_cast(ksize), static_cast(UP_DIV(outChannel, 8)), static_cast(UP_DIV(global_y, 4))}; + std::vector lws = {static_cast(ksize), 1, 1}; + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= kernel->get().setArg(idx++, static_cast(gws[0])); + ret |= kernel->get().setArg(idx++, static_cast(gws[1])); + ret |= kernel->get().setArg(idx++, static_cast(gws[2])); + ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); + if(mResource->mUseImage){ + ret |= kernel->get().setArg(idx++, *mResource->mKernelImage.get()); + }else{ + ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); + } + ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + ret |= kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); + ret |= kernel->get().setArg(idx++, static_cast(inputChannelAlign)); + ret |= kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); + ret |= kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= kernel->get().setArg(idx++, inputChannels); + ret |= kernel->get().setArg(idx++, static_cast(blockNum)); + ret |= kernel->get().setArg(idx++, static_cast(blockDim)); + ret |= kernel->get().setArg(idx++, static_cast(mResource->mCoef)); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf Kernel Select"); + std::pair, int> retTune; + int cost_time = get2DUseLocalMemTime(gws, lws, mOpenCLBackend->getOpenCLRuntime(), "gemv_conv_c8_buf" + info + "_batch", kernel); + if(min_time > cost_time) { + local_size = ksize; + min_time = cost_time; + } + } + } + buildOption.emplace("-DWGS=" + std::to_string(local_size)); + mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 8)), static_cast(UP_DIV(global_y, 4))}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", "gemv_conv_c8_buf", buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[2])); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); + if(mResource->mUseImage){ + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); + }else{ + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); + } + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelAlign)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + ret |= unit.kernel->get().setArg(idx++, static_cast(blockNum)); + ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); + ret |= unit.kernel->get().setArg(idx++, static_cast(mResource->mCoef)); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv_c8_buf"); + mLocalWorkSize = {static_cast(local_size), 1, 1}; + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + } + { + auto &unit = mUnits[2]; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "gemm_nhwc_to_c4nhw4", buildOption, mOpenCLBackend->getPrecision()); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = {static_cast(UP_DIV(global_y, 4)), static_cast(UP_DIV(outChannel, 4))}; + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_nhwc_to_c4nhw4"); + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "gemm_nhwc_to_c4nhw4", unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + } + return; + } unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName, buildOption, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); @@ -773,7 +910,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); ret |= unit.kernel->get().setArg(idx++, mResource->mCoef); MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv1x1_buf"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName + info, unit.kernel, mOpenCLBackend->getCLTuneLevel()).first; mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; @@ -814,13 +951,7 @@ ConvBufLowMemoryExecution::ConvBufLowMemoryExecution(const std::vector } else if (conv2dCommonParams->relu6()) { mResource->mBuildOptions.emplace("-DRELU6"); } - if (mResource->mNumQuantBit == 8) { - // int8 case - mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); - } else if (mResource->mNumQuantBit == 4){ - // int4 case - mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); - } else {/* More types to be supported. */} + mResource->mBuildOptions.emplace("-DQUANT_BIT=" + std::to_string(mResource->mNumQuantBit)); #ifdef LOG_VERBOSE MNN_PRINT("end ConvBufLowMemoryExecution init !\n"); #endif @@ -870,11 +1001,35 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &input if(batch == 1){ tuneGemvLowMemory(input, output); } else { - // when batch is big, convert to float weight and do gemm computation in floating field - if(batch > 128){ + std::pair, uint32_t> tuneInfo; + std::string info = "convBufLowMemory_" + std::to_string(mResource->mInputChannel) + "_" + std::to_string(mResource->mOutputChannel); + if(batch > 16){ + if(getTunedInfo(info, {static_cast(batch)}, tuneInfo, mOpenCLBackend->getOpenCLRuntime())){ + mUseFPWeight = tuneInfo.first[0]; + } else{ + if((mOpenCLBackend->getCLTuneLevel() == Heavy || mOpenCLBackend->getCLTuneLevel() == Wide)){ + setRecordClose closeRecord(mOpenCLBackend); + tuneGemmLowMemory(input, output); + auto shortBatchTime = getExecuteTime(); + mUseFPWeight = true; + useFPWeightGemmLowMemory(input, output); + auto longBatchTime = getExecuteTime(); + mUseFPWeight = false; + if(longBatchTime < shortBatchTime){ + mUseFPWeight = true; + } + std::pair, uint32_t> tuneInfoTmp = std::make_pair, uint32_t>({mUseFPWeight}, 0); + setTunedInfo(info, {static_cast(batch)}, tuneInfoTmp, mOpenCLBackend->getOpenCLRuntime()); + } else{ + if(batch > 512){ + mUseFPWeight = true; + } + } + } + } + if(mUseFPWeight){ useFPWeightGemmLowMemory(input, output); - mUseFPWeight = true; - } else { + }else{ tuneGemmLowMemory(input, output); } } @@ -900,6 +1055,66 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &input return NO_ERROR; } +int ConvBufLowMemoryExecution::getExecuteTime(){ + for (auto &unit : mUnits) { + bool lws_null = true; + for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) { + unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); + if(unit.localWorkSize.get()[i] != 0) { + lws_null = false; + } + } + if(lws_null){ + unit.localWorkSize = cl::NullRange; + } + } + int executeTime = 0; + auto runtime = mOpenCLBackend->getOpenCLRuntime(); + auto res = CL_SUCCESS; + if(mUseFPWeight){ + // arrange input and weight + int i = 0; + for (; i < 2; ++i){ + auto unit = mUnits[i]; + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + executeTime += runtime->getEventTime(event); + } + // call gemm execute + executeTime += mStrassenComputor->getExecuteTime(); + + // rearrange output + for (; i < mUnits.size(); ++i){ + auto unit = mUnits[i]; + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + executeTime += runtime->getEventTime(event); + } + }else{ + for (auto &unit : mUnits) { + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + executeTime += runtime->getEventTime(event); + } + } + return executeTime; +} + ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { #ifdef LOG_VERBOSE MNN_PRINT("Start ConvBufLowMemoryExecution onExecute !\n"); diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp index ffdf04b8..9e68c563 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp @@ -25,6 +25,7 @@ public: virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; private: + int getExecuteTime(); void getInfoFromOpLowMemory(void *weight_ptr); void set1x1WeightLowMemory(); void setGeneralWeightLowMemory(); diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp index 6c81dbe6..d040d9ce 100644 --- a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp @@ -283,7 +283,7 @@ LoopBinaryBufExecution::LoopBinaryBufExecution(const LoopParam *loop, const std: : CommonExecution(bn, op) { mLoop = loop; mTensors.resize(mLoop->tensorNumber()); - mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute); + mBuildOptions.emplace("-DOPERATOR=" + compute); } ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { diff --git a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp index 48b4f98b..2f594887 100644 --- a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp +++ b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp @@ -460,6 +460,26 @@ ErrorCode StrassenMatrixComputor::onEncode(int e, int l, int h, int as, int bs, return _generateMatMul(e, l, h, a, b, c, bias, 0, useBias); } +int StrassenMatrixComputor::getExecuteTime() { + // All is done in onResize, just execute it + auto res = CL_SUCCESS; + int executeTime = 0; + for (auto &unit : mUnits) { + if(unit.localWorkSize[0] == 0 || unit.localWorkSize[1] == 0) { + unit.localWorkSize = cl::NullRange; + } + cl::Event event; + res = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + executeTime += mOpenCLBackend->getOpenCLRuntime()->getEventTime(event); + } + return executeTime; +} + void StrassenMatrixComputor::onExecute() { // All is done in onResize, just execute it auto res = CL_SUCCESS; diff --git a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.hpp b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.hpp index dc6a9fa7..dd55399c 100644 --- a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.hpp +++ b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.hpp @@ -29,6 +29,7 @@ public: ErrorCode onEncode(int e, int l, int h, int as, int bs, int cs, const cl::Buffer AT, const cl::Buffer BT, cl::Buffer CT, bool useBias, const cl::Buffer Bias); + int getExecuteTime(); void onExecute(); void onReset(); diff --git a/source/backend/opencl/execution/cl/attention_buf.cl b/source/backend/opencl/execution/cl/attention_buf.cl index 49d83872..67495c56 100644 --- a/source/backend/opencl/execution/cl/attention_buf.cl +++ b/source/backend/opencl/execution/cl/attention_buf.cl @@ -418,6 +418,74 @@ __kernel void rearrange_v(GLOBAL_SIZE_3_DIMS #endif } +__kernel void rearrange_mask_shortprefill(GLOBAL_SIZE_3_DIMS + #ifdef ADD_MASK + __global const FLOAT* mask, + __global FLOAT* maskout, + #else + __global const int* mask, // [1 1 query_seq_len mask_key_seq_len4] + __global int* maskout, // [1 1 mask_key_seq_len4 query_seq_len4] + #endif + __private const int query_seq_len, + __private const int mask_key_seq_len){ + const int x = get_global_id(0); // query_seq_len4 + const int y = get_global_id(1); // mask_key_seq_len4 + const int z = get_global_id(2); // batch + DEAL_NON_UNIFORM_DIM3(x, y, z); + + const int x4 = x << 2; + const int y4 = y << 2; + float4 mask_tmp0, mask_tmp1, mask_tmp2, mask_tmp3; + float4 mask0, mask1, mask2, mask3; + int mask_offset = x4 * mask_key_seq_len + y4; + if(x4 + 3 < query_seq_len && y4 + 3 < mask_key_seq_len){ + mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp1 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp2 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp3 = convert_float4(vload4(0, mask + mask_offset)); + } else{ + if(y4 + 3 < mask_key_seq_len){ + mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; + mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); + } else if(y4 + 1 == mask_key_seq_len){ + mask_tmp0 = (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len; + mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len; + mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0); mask_offset += mask_key_seq_len; + mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], 0, 0, 0); + }else if(y4 + 2 == mask_key_seq_len){ + mask_tmp0 = (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len; + mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len; + mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); mask_offset += mask_key_seq_len; + mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], 0, 0); + }else if(y4 + 3 == mask_key_seq_len){ + mask_tmp0 = (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len; + mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len; + mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); mask_offset += mask_key_seq_len; + mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset], mask[mask_offset + 1], mask[mask_offset + 2], 0); + } + } + mask0 = (float4)(mask_tmp0.s0, mask_tmp1.s0, mask_tmp2.s0, mask_tmp3.s0); + mask1 = (float4)(mask_tmp0.s1, mask_tmp1.s1, mask_tmp2.s1, mask_tmp3.s1); + mask2 = (float4)(mask_tmp0.s2, mask_tmp1.s2, mask_tmp2.s2, mask_tmp3.s2); + mask3 = (float4)(mask_tmp0.s3, mask_tmp1.s3, mask_tmp2.s3, mask_tmp3.s3); + + int query_seq_len4 = ((query_seq_len + 3) / 4) * 4; + int output_offset = y4 * query_seq_len4 + x4; + #ifdef ADD_MASK + vstore4(CONVERT_FLOAT4(mask0), 0, maskout + output_offset); + vstore4(CONVERT_FLOAT4(mask1), 0, maskout + output_offset + query_seq_len4); + vstore4(CONVERT_FLOAT4(mask2), 0, maskout + output_offset + query_seq_len4 + query_seq_len4); + vstore4(CONVERT_FLOAT4(mask3), 0, maskout + output_offset + query_seq_len4 + query_seq_len4 + query_seq_len4); + #else + vstore4(convert_int4(mask0), 0, maskout + output_offset); + vstore4(convert_int4(mask1), 0, maskout + output_offset + query_seq_len4); + vstore4(convert_int4(mask2), 0, maskout + output_offset + query_seq_len4 + query_seq_len4); + vstore4(convert_int4(mask3), 0, maskout + output_offset + query_seq_len4 + query_seq_len4 + query_seq_len4); + #endif +} + __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS __global const FLOAT *query, // [batch head_num head_dim_4 query_seq_len_4] __global const FLOAT *past_key, // [batch kv_head_num head_dim_4 kv_max_length] @@ -486,15 +554,13 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS out3 *= (float4)scale; { #if defined(ADD_MASK) || defined(SET_MASK) - int mask_offset = x4 * mask_key_seq_len + y4; - float4 mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; - float4 mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; - float4 mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len; - float4 mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); - float4 mask0 = (float4)(mask_tmp0.s0, mask_tmp1.s0, mask_tmp2.s0, mask_tmp3.s0); - float4 mask1 = (float4)(mask_tmp0.s1, mask_tmp1.s1, mask_tmp2.s1, mask_tmp3.s1); - float4 mask2 = (float4)(mask_tmp0.s2, mask_tmp1.s2, mask_tmp2.s2, mask_tmp3.s2); - float4 mask3 = (float4)(mask_tmp0.s3, mask_tmp1.s3, mask_tmp2.s3, mask_tmp3.s3); + int query_seq_len4 = ((query_seq_len + 3) / 4) * 4; + int mask_clp = y4 + mask_key_seq_len - key_seq_len; + int mask_offset = mask_clp * query_seq_len4 + x4; + float4 mask0 = mask_clp >= 0 && mask_clp < mask_key_seq_len ? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4; + float4 mask1 = mask_clp + 1 >= 0 && mask_clp + 1 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4; + float4 mask2 = mask_clp + 2 >= 0 && mask_clp + 2 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0; mask_offset += query_seq_len4; + float4 mask3 = mask_clp + 3 >= 0 && mask_clp + 3 < mask_key_seq_len? convert_float4(vload4(0, mask + mask_offset)) : 0; #endif #ifdef ADD_MASK diff --git a/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp index b4f4cd9a..94fea8f0 100644 --- a/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/attention_buf_mnn_cl.cpp @@ -361,6 +361,73 @@ const char* attention_buf = " vstore4(value_vec,0,past_value+output_offset);\n" "#endif\n" "}\n" +"__kernel void rearrange_mask_shortprefill(GLOBAL_SIZE_3_DIMS\n" +" #ifdef ADD_MASK\n" +" __global const FLOAT* mask,\n" +" __global FLOAT* maskout,\n" +" #else\n" +" __global const int* mask,// [1 1 query_seq_len mask_key_seq_len4]\n" +" __global int* maskout,// [1 1 mask_key_seq_len4 query_seq_len4]\n" +" #endif\n" +" __private const int query_seq_len,\n" +" __private const int mask_key_seq_len){\n" +" const int x=get_global_id(0); // query_seq_len4\n" +" const int y=get_global_id(1); // mask_key_seq_len4\n" +" const int z=get_global_id(2); // batch\n" +" DEAL_NON_UNIFORM_DIM3(x,y,z);\n" +" \n" +" const int x4=x << 2;\n" +" const int y4=y << 2;\n" +" float4 mask_tmp0,mask_tmp1,mask_tmp2,mask_tmp3;\n" +" float4 mask0,mask1,mask2,mask3;\n" +" int mask_offset=x4*mask_key_seq_len+y4;\n" +" if(x4+3= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n" +" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n" +" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n" +" } else if(y4+1 == mask_key_seq_len){\n" +" mask_tmp0=(float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0);\n" +" }else if(y4+2 == mask_key_seq_len){\n" +" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n" +" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0);\n" +" }else if(y4+3 == mask_key_seq_len){\n" +" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n" +" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n" +" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n" +" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0);\n" +" }\n" +" }\n" +" mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n" +" mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n" +" mask2=(float4)(mask_tmp0.s2,mask_tmp1.s2,mask_tmp2.s2,mask_tmp3.s2);\n" +" mask3=(float4)(mask_tmp0.s3,mask_tmp1.s3,mask_tmp2.s3,mask_tmp3.s3);\n" +" \n" +" int query_seq_len4=((query_seq_len+3)/4)*4;\n" +" int output_offset=y4*query_seq_len4+x4;\n" +" #ifdef ADD_MASK\n" +" vstore4(CONVERT_FLOAT4(mask0),0,maskout+output_offset);\n" +" vstore4(CONVERT_FLOAT4(mask1),0,maskout+output_offset+query_seq_len4);\n" +" vstore4(CONVERT_FLOAT4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n" +" vstore4(CONVERT_FLOAT4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n" +" #else\n" +" vstore4(convert_int4(mask0),0,maskout+output_offset);\n" +" vstore4(convert_int4(mask1),0,maskout+output_offset+query_seq_len4);\n" +" vstore4(convert_int4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n" +" vstore4(convert_int4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n" +" #endif\n" +"}\n" "__kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS\n" " __global const FLOAT *query,// [batch head_num head_dim_4 query_seq_len_4]\n" " __global const FLOAT *past_key,// [batch kv_head_num head_dim_4 kv_max_length]\n" @@ -428,15 +495,13 @@ const char* attention_buf = " out3 *= (float4)scale;\n" " {\n" " #if defined(ADD_MASK) || defined(SET_MASK)\n" -" int mask_offset=x4*mask_key_seq_len+y4;\n" -" float4 mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n" -" float4 mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n" -" float4 mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n" -" float4 mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n" -" float4 mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n" -" float4 mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n" -" float4 mask2=(float4)(mask_tmp0.s2,mask_tmp1.s2,mask_tmp2.s2,mask_tmp3.s2);\n" -" float4 mask3=(float4)(mask_tmp0.s3,mask_tmp1.s3,mask_tmp2.s3,mask_tmp3.s3);\n" +" int query_seq_len4=((query_seq_len+3)/4)*4;\n" +" int mask_clp=y4+mask_key_seq_len-key_seq_len;\n" +" int mask_offset=mask_clp*query_seq_len4+x4;\n" +" float4 mask0=mask_clp >= 0 && mask_clp= 0 && mask_clp+1= 0 && mask_clp+2= 0 && mask_clp+3> 4) - 8; - charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8; - charWeights0.z = (charWeightsInt4.s1 >> 4) - 8; - charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8; - charWeights1.x = (charWeightsInt4.s2 >> 4) - 8; - charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8; - charWeights1.z = (charWeightsInt4.s3 >> 4) - 8; - charWeights1.w = (charWeightsInt4.s3 & MOD_NUM)- 8; - charWeights2.x = (charWeightsInt4.s4 >> 4) - 8; - charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8; - charWeights2.z = (charWeightsInt4.s5 >> 4) - 8; - charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8; - charWeights3.x = (charWeightsInt4.s6 >> 4) - 8; - charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8; - charWeights3.z = (charWeightsInt4.s7 >> 4) - 8; - charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0); -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(weights_width_base, weights + weight_offset); weights1 = vload4(weights_width_base + 1, weights + weight_offset); weights2 = vload4(weights_width_base + 2, weights + weight_offset); @@ -319,7 +252,6 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_block_idx)); weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_block_idx)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); @@ -368,17 +300,11 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER __global const FLOAT *weights, #else __read_only image2d_t weights, @@ -390,10 +316,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __private const int2 stride_shape, __private const int output_width_4, __private const int out_channel_blocks -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int inChannel -#endif ) { const int output_channel_width_idx = get_global_id(0); @@ -404,13 +326,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, const int output_width_block_idx = output_channel_width_idx % output_width_4; const int output_channel_idx = output_channel_block_idx << 1; -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_ic_offset = output_channel_block_idx * 16; - int weight_oc_offset = out_channel_blocks * 8; -#else int weight_ic_offset = output_channel_block_idx * 32; int weight_oc_offset = out_channel_blocks * 16; -#endif FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_idx, 0)); FLOAT4 out1 = out0; FLOAT4 out2 = out0; @@ -457,16 +374,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, int weight_offset1 = weight_offset + in_channel_block * 4 * 4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; - // already pack to 16, no need boundry protect - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); - COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); - COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx + 1, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); - COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); -#endif int input_width_base = in_channel_block_idx * input_shape.y; int weights_width_base = in_channel_block_idx << 2; @@ -475,72 +382,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - FLOAT16 weightsInt80 = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); - #ifdef CHANNEL_BOUNDARY_PROTECT - FLOAT16 weightsInt81 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); - #else - FLOAT16 weightsInt81 = CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); - #endif - FLOAT4 weights0 = CONVERT_FLOAT4(weightsInt80.s0123) * scale0 + offset0; - FLOAT4 weights1 = CONVERT_FLOAT4(weightsInt80.s4567) * scale0 + offset0; - FLOAT4 weights2 = CONVERT_FLOAT4(weightsInt80.s89ab) * scale0 + offset0; - FLOAT4 weights3 = CONVERT_FLOAT4(weightsInt80.scdef) * scale0 + offset0; - FLOAT4 weights4 = CONVERT_FLOAT4(weightsInt81.s0123) * scale1 + offset1; - FLOAT4 weights5 = CONVERT_FLOAT4(weightsInt81.s4567) * scale1 + offset1; - FLOAT4 weights6 = CONVERT_FLOAT4(weightsInt81.s89ab) * scale1 + offset1; - FLOAT4 weights7 = CONVERT_FLOAT4(weightsInt81.scdef) * scale1 + offset1; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar16 charWeightsInt4 = vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset); - char4 charWeights0 = (char4)(0, 0, 0, 0); - char4 charWeights1 = (char4)(0, 0, 0, 0); - char4 charWeights2 = (char4)(0, 0, 0, 0); - char4 charWeights3 = (char4)(0, 0, 0, 0); - char4 charWeights4 = (char4)(0, 0, 0, 0); - char4 charWeights5 = (char4)(0, 0, 0, 0); - char4 charWeights6 = (char4)(0, 0, 0, 0); - char4 charWeights7 = (char4)(0, 0, 0, 0); - charWeights0.x = (charWeightsInt4.s0 >> 4) - 8; - charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8; - charWeights0.z = (charWeightsInt4.s1 >> 4) - 8; - charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8; - charWeights1.x = (charWeightsInt4.s2 >> 4) - 8; - charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8; - charWeights1.z = (charWeightsInt4.s3 >> 4) - 8; - charWeights1.w = (charWeightsInt4.s3 & MOD_NUM) - 8; - charWeights2.x = (charWeightsInt4.s4 >> 4) - 8; - charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8; - charWeights2.z = (charWeightsInt4.s5 >> 4) - 8; - charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8; - charWeights3.x = (charWeightsInt4.s6 >> 4) - 8; - charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8; - charWeights3.z = (charWeightsInt4.s7 >> 4) - 8; - charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8; - charWeights4.x = (charWeightsInt4.s8 >> 4) - 8; - charWeights4.y = (charWeightsInt4.s8 & MOD_NUM) - 8; - charWeights4.z = (charWeightsInt4.s9 >> 4) - 8; - charWeights4.w = (charWeightsInt4.s9 & MOD_NUM) - 8; - charWeights5.x = (charWeightsInt4.sa >> 4) - 8; - charWeights5.y = (charWeightsInt4.sa & MOD_NUM) - 8; - charWeights5.z = (charWeightsInt4.sb >> 4) - 8; - charWeights5.w = (charWeightsInt4.sb & MOD_NUM) - 8; - charWeights6.x = (charWeightsInt4.sc >> 4) - 8; - charWeights6.y = (charWeightsInt4.sc & MOD_NUM) - 8; - charWeights6.z = (charWeightsInt4.sd >> 4) - 8; - charWeights6.w = (charWeightsInt4.sd & MOD_NUM) - 8; - charWeights7.x = (charWeightsInt4.se >> 4) - 8; - charWeights7.y = (charWeightsInt4.se & MOD_NUM) - 8; - charWeights7.z = (charWeightsInt4.sf >> 4) - 8; - charWeights7.w = (charWeightsInt4.sf & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0); - weights4 = mad(CONVERT_FLOAT4(charWeights4), scale1, offset1); - weights5 = mad(CONVERT_FLOAT4(charWeights5), scale1, offset1); - weights6 = mad(CONVERT_FLOAT4(charWeights6), scale1, offset1); - weights7 = mad(CONVERT_FLOAT4(charWeights7), scale1, offset1); -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(weights_width_base, weights + weight_offset); weights1 = vload4(weights_width_base + 1, weights + weight_offset); weights2 = vload4(weights_width_base + 2, weights + weight_offset); @@ -568,8 +410,6 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights6 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx + 1)); weights7 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx + 1)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); - PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); @@ -646,17 +486,11 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER __global const FLOAT *weights, #else __read_only image2d_t weights, @@ -675,10 +509,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __private const int out_width_blocks, __private const int out_channel_blocks, __private const int out_height_blocks -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int inChannel -#endif ) { const int output_channel_width_idx = get_global_id(0); @@ -720,23 +550,15 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), weights_shape.y); #endif -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; #endif FLOAT4 in0, in1, in2, in3; FLOAT4 weights0, weights1, weights2, weights3; for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { - -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); - COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); -#endif - const int in_idx = mul24(in_channel_block_idx, input_shape.y); -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + kh_start)*weights_shape.y + 0) * 4; #else int weights_x_idx = in_channel_block_idx << 2; @@ -751,47 +573,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, READ_INPUT_IMAGE(2, 0); READ_INPUT_IMAGE(3, 0); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); - char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); - char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); - uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); - uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); - uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); - char4 charWeight0 = (char4)(0, 0, 0, 0); - char4 charWeight1 = (char4)(0, 0, 0, 0); - char4 charWeight2 = (char4)(0, 0, 0, 0); - char4 charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(0, weights+weight_offset); weights1 = vload4(0, weights+weight_offset+weight_oc_offset); weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); @@ -803,7 +585,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); CALCULATE_OUTPUT(2); @@ -814,47 +595,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, in1 = in2; in2 = in3; READ_INPUT_IMAGE(3, w); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); - char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); - char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); - uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); - uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); - uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); - char4 charWeight0 = (char4)(0, 0, 0, 0); - char4 charWeight1 = (char4)(0, 0, 0, 0); - char4 charWeight2 = (char4)(0, 0, 0, 0); - char4 charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(0, weights+weight_offset); weights1 = vload4(0, weights+weight_offset+weight_oc_offset); weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); @@ -866,7 +607,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); CALCULATE_OUTPUT(2); @@ -879,47 +619,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, READ_INPUT_IMAGE(1, input_width_base); READ_INPUT_IMAGE(2, input_width_base); READ_INPUT_IMAGE(3, input_width_base); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); - char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); - char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); - uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); - uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); - uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); - char4 charWeight0 = (char4)(0, 0, 0, 0); - char4 charWeight1 = (char4)(0, 0, 0, 0); - char4 charWeight2 = (char4)(0, 0, 0, 0); - char4 charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(0, weights+weight_offset); weights1 = vload4(0, weights+weight_offset+weight_oc_offset); weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); @@ -931,7 +631,6 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); CALCULATE_OUTPUT(2); @@ -978,17 +677,11 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER __global const FLOAT *weights, #else __read_only image2d_t weights, @@ -1007,10 +700,6 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __private const int out_width_blocks, __private const int out_channel_blocks, __private const int out_height_blocks -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int inChannel -#endif ) { const int output_channel_width_idx = get_global_id(0); @@ -1036,7 +725,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, FLOAT4 out6 = out4; FLOAT4 out7 = out4; -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER const int weight_oc_offset = weights_shape.x * weights_shape.y * 4; const int weight_ic_offset = out_channel_blocks * weight_oc_offset; #endif @@ -1054,18 +743,8 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, FLOAT4 in0, in1, in2, in3; FLOAT4 weights0, weights1, weights2, weights3, weights4, weights5, weights6, weights7; for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); - COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); - COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx + 1, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); - COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); - -#endif const int in_idx = mul24(in_channel_block_idx, input_shape.y); -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; #else int weights_x_idx = in_channel_block_idx << 2; @@ -1083,92 +762,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, in1 = RI_F(input, SAMPLER, (int2)(w0, h1)); in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); - -#if (defined USE_LOW_BIT_WEIGHT_INT8) - char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); - char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset); - char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*2); - char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*3); - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - #ifdef CHANNEL_BOUNDARY_PROTECT - charWeight0 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - charWeight1 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); - charWeight2 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); - charWeight3 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); - - #else - charWeight0 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); - charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); - charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); - #endif - weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1); - weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1); - weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1); - weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1); - weight_offset += 4; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); - uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset/2); - uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*2/2); - uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*3/2); - char4 charWeight0 = (char4)(0, 0, 0, 0); - char4 charWeight1 = (char4)(0, 0, 0, 0); - char4 charWeight2 = (char4)(0, 0, 0, 0); - char4 charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); - charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2); - charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2); - charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2); - charWeight0 = (char4)(0, 0, 0, 0); - charWeight1 = (char4)(0, 0, 0, 0); - charWeight2 = (char4)(0, 0, 0, 0); - charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM)- 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1); - weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1); - weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1); - weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1); - weight_offset += 4; -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(0, weights+weight_offset); weights1 = vload4(0, weights+weight_offset+weight_ic_offset); weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2); @@ -1195,8 +789,6 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights6 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weight_size + weights_y_idx)); weights7 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weight_size + weights_y_idx++)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); - PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); @@ -1279,17 +871,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *kernel_ptr, - __global const float *dequantScaleOffset, -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER __global const FLOAT *weights, #else __read_only image2d_t weights, @@ -1308,10 +894,6 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __private const int out_width_blocks, __private const int out_channel_blocks, __private const int out_height_blocks -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int inChannel -#endif ) { const int output_channel_width_idx = get_global_id(0); @@ -1344,18 +926,12 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, FLOAT4 in0, in1, in2, in3; FLOAT4 weights0, weights1, weights2, weights3; -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; #endif for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); - COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); -#endif const int in_idx = mul24(in_channel_block_idx, input_shape.y); -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER) +#ifdef USE_BUFFER int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; #else int weights_x_idx = in_channel_block_idx << 2; @@ -1374,47 +950,7 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); - char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); - char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); - char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); - uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); - uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); - uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); - char4 charWeight0 = (char4)(0, 0, 0, 0); - char4 charWeight1 = (char4)(0, 0, 0, 0); - char4 charWeight2 = (char4)(0, 0, 0, 0); - char4 charWeight3 = (char4)(0, 0, 0, 0); - charWeight0.x = (charWeightInt40.s0 >> 4) - 8; - charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; - charWeight0.z = (charWeightInt40.s1 >> 4) - 8; - charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; - charWeight1.x = (charWeightInt41.s0 >> 4) - 8; - charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; - charWeight1.z = (charWeightInt41.s1 >> 4) - 8; - charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; - charWeight2.x = (charWeightInt42.s0 >> 4) - 8; - charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; - charWeight2.z = (charWeightInt42.s1 >> 4) - 8; - charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; - charWeight3.x = (charWeightInt43.s0 >> 4) - 8; - charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; - charWeight3.z = (charWeightInt43.s1 >> 4) - 8; - charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; - weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); - weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); - weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); - weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); - weight_offset += 4; -#elif (defined USE_BUFFER) +#ifdef USE_BUFFER weights0 = vload4(0, weights+weight_offset); weights1 = vload4(0, weights+weight_offset+weight_oc_offset); weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); @@ -1426,7 +962,6 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); #endif - PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); CALCULATE_OUTPUT(0); CALCULATE_OUTPUT(1); diff --git a/source/backend/opencl/execution/cl/conv_2d_int.cl b/source/backend/opencl/execution/cl/conv_2d_int.cl new file mode 100644 index 00000000..ed5447a6 --- /dev/null +++ b/source/backend/opencl/execution/cl/conv_2d_int.cl @@ -0,0 +1,1179 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif +#define READ_INPUT_IMAGE(i, base) \ + int in_width_value##i = in_width##i + base; \ + in_width_value##i = \ + select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \ + in##i = RI_F(input, SAMPLER, (int2)(in_width_value##i, in_hb_value)); + +#define CALCULATE_OUTPUT(i) \ + out##i = mad(in##i.x, weights0, out##i); \ + out##i = mad(in##i.y, weights1, out##i); \ + out##i = mad(in##i.z, weights2, out##i); \ + out##i = mad(in##i.w, weights3, out##i); + +#define CALCULATE_OUTPUT_WEIGHTS4(i, j) \ + out##i = mad(in##j.x, weights4, out##i); \ + out##i = mad(in##j.y, weights5, out##i); \ + out##i = mad(in##j.z, weights6, out##i); \ + out##i = mad(in##j.w, weights7, out##i); + +#define CALCULATE_OUTPUT_OPT(i) \ + out##i = mad(in_sm##i[local_idx].x, weights0, out##i); \ + out##i = mad(in_sm##i[local_idx].y, weights1, out##i); \ + out##i = mad(in_sm##i[local_idx].z, weights2, out##i); \ + out##i = mad(in_sm##i[local_idx].w, weights3, out##i); + +#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1, + +__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +#define DEAL_NON_UNIFORM_DIM2(input1, input2) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ + return; \ + } + +#define GLOBAL_SIZE_3_DIMS \ + __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, + +#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ + return; \ + } + +#define UNIT 4 +#define MOD_NUM 15 + +#ifdef INPUT_CHANNEL_LEAVE + #define PADZEROSVEC(k, channel, data0, data1, data2, data3) \ + data0 = (k << 2) < channel ? data0 : 0; \ + data1 = (k << 2) + 1 < channel ? data1 : 0; \ + data2 = (k << 2) + 2 < channel ? data2 : 0; \ + data3 = (k << 2) + 3 < channel ? data3 : 0; +#else + #define PADZEROSVEC(k, channel, data0, data1, data2, data3) +#endif + +__kernel +#ifdef SET_ATTRIBUTE +__attribute__((work_group_size_hint(16, 16, 1))) +#endif +void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *kernel_ptr, + __global const float *dequantScaleOffset, +#else + __global const uchar *kernel_ptr, + __global const float *dequantScaleOffset, +#endif + __read_only image2d_t bias, + __write_only image2d_t output, + __private const int2 input_shape, + __private const int in_channel_block, __private const int2 output_shape, + __private const int2 stride_shape, + __private const int output_width_4, + __private const int out_channel_blocks + ,__private const int blockDim + ,__private const int inChannel +) { + + const int output_channel_width_idx = get_global_id(0); + const int output_batch_height_idx = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); + + const int output_channel_block_idx = output_channel_width_idx / output_width_4; + const int output_width_block_idx = output_channel_width_idx % output_width_4; + +#if QUANT_BIT == 4 + int weight_ic_offset = output_channel_block_idx * 8; + int weight_oc_offset = out_channel_blocks * 8; +#else + int weight_ic_offset = output_channel_block_idx * 16; + int weight_oc_offset = out_channel_blocks * 16; +#endif + + FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_block_idx, 0)); + FLOAT4 out1 = out0; + FLOAT4 out2 = out0; + FLOAT4 out3 = out0; + +#ifdef MNN_CONV_S1D1 + int intput_width_idx0 = output_width_block_idx << 2; + int intput_width_idx1 = intput_width_idx0 + 1; + int intput_width_idx2 = intput_width_idx0 + 2; + int intput_width_idx3 = intput_width_idx0 + 3; +#else + int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4); + int intput_width_idx1 = intput_width_idx0 + stride_shape.y; + int intput_width_idx2 = intput_width_idx1 + stride_shape.y; + int intput_width_idx3 = intput_width_idx2 + stride_shape.y; + + intput_width_idx0 = select(intput_width_idx0, INT_MIN, intput_width_idx0 >= input_shape.y); + intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y); + intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y); + intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y); +#endif + + int batch_index = output_batch_height_idx / output_shape.x; + int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x; + + FLOAT4 in0; + FLOAT4 in1; + FLOAT4 in2; + FLOAT4 in3; + FLOAT4 weights0; + FLOAT4 weights1; + FLOAT4 weights2; + FLOAT4 weights3; + int weight_offset = output_channel_block_idx * in_channel_block * 4 * 4; + + for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { + int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); + COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); + + int input_width_base = in_channel_block_idx * input_shape.y; + int weights_width_base = in_channel_block_idx << 2; + +#if QUANT_BIT == 8 + FLOAT16 weights = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + FLOAT4 weights0 = CONVERT_FLOAT4(weights.s0123) * scale0 + offset0; + FLOAT4 weights1 = CONVERT_FLOAT4(weights.s4567) * scale0 + offset0; + FLOAT4 weights2 = CONVERT_FLOAT4(weights.s89ab) * scale0 + offset0; + FLOAT4 weights3 = CONVERT_FLOAT4(weights.scdef) * scale0 + offset0; +#else + uchar8 charWeightsInt4 = vload8(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset); + char4 charWeights0 = (char4)(0, 0, 0, 0); + char4 charWeights1 = (char4)(0, 0, 0, 0); + char4 charWeights2 = (char4)(0, 0, 0, 0); + char4 charWeights3 = (char4)(0, 0, 0, 0); + charWeights0.x = (charWeightsInt4.s0 >> 4) - 8; + charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8; + charWeights0.z = (charWeightsInt4.s1 >> 4) - 8; + charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8; + charWeights1.x = (charWeightsInt4.s2 >> 4) - 8; + charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8; + charWeights1.z = (charWeightsInt4.s3 >> 4) - 8; + charWeights1.w = (charWeightsInt4.s3 & MOD_NUM)- 8; + charWeights2.x = (charWeightsInt4.s4 >> 4) - 8; + charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8; + charWeights2.z = (charWeightsInt4.s5 >> 4) - 8; + charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8; + charWeights3.x = (charWeightsInt4.s6 >> 4) - 8; + charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8; + charWeights3.z = (charWeightsInt4.s7 >> 4) - 8; + charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0); +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); + in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); + in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); + in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); + + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + } + +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); + out2 = fmax(out2, (FLOAT4)0); + out3 = fmax(out3, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); + out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); + out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); +#endif + + const int out_x_base = mul24(output_channel_block_idx, output_shape.y); + int out_x_idx = output_width_block_idx << 2; + + const int remain = output_shape.y - out_x_idx; + int output_idx = out_x_base + out_x_idx; + if (remain >= 4) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3); + } else if (remain == 3) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + } else if (remain == 2) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + } else if (remain == 1) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + } +} + +__kernel +#ifdef SET_ATTRIBUTE +__attribute__((work_group_size_hint(16, 16, 1))) +#endif +void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *kernel_ptr, + __global const float *dequantScaleOffset, +#else + __global const uchar *kernel_ptr, + __global const float *dequantScaleOffset, +#endif + __read_only image2d_t bias, + __write_only image2d_t output, + __private const int2 input_shape, + __private const int in_channel_block, __private const int2 output_shape, + __private const int2 stride_shape, + __private const int output_width_4, + __private const int out_channel_blocks + ,__private const int blockDim + ,__private const int inChannel +) { + + const int output_channel_width_idx = get_global_id(0); + const int output_batch_height_idx = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); + + const int output_channel_block_idx = output_channel_width_idx / output_width_4; + const int output_width_block_idx = output_channel_width_idx % output_width_4; + const int output_channel_idx = output_channel_block_idx << 1; + +#if QUANT_BIT == 4 + int weight_ic_offset = output_channel_block_idx * 16; + int weight_oc_offset = out_channel_blocks * 8; +#else + int weight_ic_offset = output_channel_block_idx * 32; + int weight_oc_offset = out_channel_blocks * 16; +#endif + FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_idx, 0)); + FLOAT4 out1 = out0; + FLOAT4 out2 = out0; + FLOAT4 out3 = out0; + + FLOAT4 out4 = RI_F(bias, SAMPLER, (int2)(output_channel_idx + 1, 0)); + FLOAT4 out5 = out4; + FLOAT4 out6 = out4; + FLOAT4 out7 = out4; + +#ifdef MNN_CONV_S1D1 + int intput_width_idx0 = output_width_block_idx << 2; + int intput_width_idx1 = intput_width_idx0 + 1; + int intput_width_idx2 = intput_width_idx0 + 2; + int intput_width_idx3 = intput_width_idx0 + 3; +#else + int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4); + int intput_width_idx1 = intput_width_idx0 + stride_shape.y; + int intput_width_idx2 = intput_width_idx1 + stride_shape.y; + int intput_width_idx3 = intput_width_idx2 + stride_shape.y; + + intput_width_idx0 = select(intput_width_idx0, INT_MIN, intput_width_idx0 >= input_shape.y); + intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y); + intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y); + intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y); +#endif + + int batch_index = output_batch_height_idx / output_shape.x; + int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x; + + FLOAT4 in0; + FLOAT4 in1; + FLOAT4 in2; + FLOAT4 in3; + FLOAT4 weights0; + FLOAT4 weights1; + FLOAT4 weights2; + FLOAT4 weights3; + FLOAT4 weights4; + FLOAT4 weights5; + FLOAT4 weights6; + FLOAT4 weights7; + int weight_offset = output_channel_idx * in_channel_block * 4 * 4; + int weight_offset1 = weight_offset + in_channel_block * 4 * 4; + + for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { + int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + // already pack to 16, no need boundry protect + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); + COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); + COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx + 1, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); + COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); + + int input_width_base = in_channel_block_idx * input_shape.y; + int weights_width_base = in_channel_block_idx << 2; + in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); + in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); + in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); + in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); + +#if QUANT_BIT == 8 + FLOAT16 weightsInt80 = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #ifdef CHANNEL_BOUNDARY_PROTECT + FLOAT16 weightsInt81 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #else + FLOAT16 weightsInt81 = CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #endif + FLOAT4 weights0 = CONVERT_FLOAT4(weightsInt80.s0123) * scale0 + offset0; + FLOAT4 weights1 = CONVERT_FLOAT4(weightsInt80.s4567) * scale0 + offset0; + FLOAT4 weights2 = CONVERT_FLOAT4(weightsInt80.s89ab) * scale0 + offset0; + FLOAT4 weights3 = CONVERT_FLOAT4(weightsInt80.scdef) * scale0 + offset0; + FLOAT4 weights4 = CONVERT_FLOAT4(weightsInt81.s0123) * scale1 + offset1; + FLOAT4 weights5 = CONVERT_FLOAT4(weightsInt81.s4567) * scale1 + offset1; + FLOAT4 weights6 = CONVERT_FLOAT4(weightsInt81.s89ab) * scale1 + offset1; + FLOAT4 weights7 = CONVERT_FLOAT4(weightsInt81.scdef) * scale1 + offset1; +#else + uchar16 charWeightsInt4 = vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset); + char4 charWeights0 = (char4)(0, 0, 0, 0); + char4 charWeights1 = (char4)(0, 0, 0, 0); + char4 charWeights2 = (char4)(0, 0, 0, 0); + char4 charWeights3 = (char4)(0, 0, 0, 0); + char4 charWeights4 = (char4)(0, 0, 0, 0); + char4 charWeights5 = (char4)(0, 0, 0, 0); + char4 charWeights6 = (char4)(0, 0, 0, 0); + char4 charWeights7 = (char4)(0, 0, 0, 0); + charWeights0.x = (charWeightsInt4.s0 >> 4) - 8; + charWeights0.y = (charWeightsInt4.s0 & MOD_NUM) - 8; + charWeights0.z = (charWeightsInt4.s1 >> 4) - 8; + charWeights0.w = (charWeightsInt4.s1 & MOD_NUM) - 8; + charWeights1.x = (charWeightsInt4.s2 >> 4) - 8; + charWeights1.y = (charWeightsInt4.s2 & MOD_NUM) - 8; + charWeights1.z = (charWeightsInt4.s3 >> 4) - 8; + charWeights1.w = (charWeightsInt4.s3 & MOD_NUM) - 8; + charWeights2.x = (charWeightsInt4.s4 >> 4) - 8; + charWeights2.y = (charWeightsInt4.s4 & MOD_NUM) - 8; + charWeights2.z = (charWeightsInt4.s5 >> 4) - 8; + charWeights2.w = (charWeightsInt4.s5 & MOD_NUM) - 8; + charWeights3.x = (charWeightsInt4.s6 >> 4) - 8; + charWeights3.y = (charWeightsInt4.s6 & MOD_NUM) - 8; + charWeights3.z = (charWeightsInt4.s7 >> 4) - 8; + charWeights3.w = (charWeightsInt4.s7 & MOD_NUM) - 8; + charWeights4.x = (charWeightsInt4.s8 >> 4) - 8; + charWeights4.y = (charWeightsInt4.s8 & MOD_NUM) - 8; + charWeights4.z = (charWeightsInt4.s9 >> 4) - 8; + charWeights4.w = (charWeightsInt4.s9 & MOD_NUM) - 8; + charWeights5.x = (charWeightsInt4.sa >> 4) - 8; + charWeights5.y = (charWeightsInt4.sa & MOD_NUM) - 8; + charWeights5.z = (charWeightsInt4.sb >> 4) - 8; + charWeights5.w = (charWeightsInt4.sb & MOD_NUM) - 8; + charWeights6.x = (charWeightsInt4.sc >> 4) - 8; + charWeights6.y = (charWeightsInt4.sc & MOD_NUM) - 8; + charWeights6.z = (charWeightsInt4.sd >> 4) - 8; + charWeights6.w = (charWeightsInt4.sd & MOD_NUM) - 8; + charWeights7.x = (charWeightsInt4.se >> 4) - 8; + charWeights7.y = (charWeightsInt4.se & MOD_NUM) - 8; + charWeights7.z = (charWeightsInt4.sf >> 4) - 8; + charWeights7.w = (charWeightsInt4.sf & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeights0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeights1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeights2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeights3), scale0, offset0); + weights4 = mad(CONVERT_FLOAT4(charWeights4), scale1, offset1); + weights5 = mad(CONVERT_FLOAT4(charWeights5), scale1, offset1); + weights6 = mad(CONVERT_FLOAT4(charWeights6), scale1, offset1); + weights7 = mad(CONVERT_FLOAT4(charWeights7), scale1, offset1); +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7); + + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + + CALCULATE_OUTPUT_WEIGHTS4(4, 0); + CALCULATE_OUTPUT_WEIGHTS4(5, 1); + CALCULATE_OUTPUT_WEIGHTS4(6, 2); + CALCULATE_OUTPUT_WEIGHTS4(7, 3); + } + +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); + out2 = fmax(out2, (FLOAT4)0); + out3 = fmax(out3, (FLOAT4)0); + out4 = fmax(out4, (FLOAT4)0); + out5 = fmax(out5, (FLOAT4)0); + out6 = fmax(out6, (FLOAT4)0); + out7 = fmax(out7, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); + out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); + out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); + out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6); + out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6); + out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6); + out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6); +#endif + + const int out_x_base = mul24(output_channel_idx, output_shape.y); + int out_x_idx = output_width_block_idx << 2; + + const int remain = output_shape.y - out_x_idx; + int output_idx = out_x_base + out_x_idx; + if (remain >= 4) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3); + } else if (remain == 3) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + } else if (remain == 2) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + } else if (remain == 1) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + } + + if(output_channel_idx + 1 >= out_channel_blocks) + return; + output_idx += output_shape.y; + if (remain >= 4) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6); + WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out7); + } else if (remain == 3) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6); + } else if (remain == 2) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); + } else if (remain == 1) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); + } +} + +__kernel +#ifdef SET_ATTRIBUTE +__attribute__((work_group_size_hint(16, 16, 1))) +#endif +void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *kernel_ptr, + __global const float *dequantScaleOffset, +#else + __global const uchar *kernel_ptr, + __global const float *dequantScaleOffset, +#endif +#ifdef BIAS + __read_only image2d_t bias, +#endif + __write_only image2d_t output, + __private const int2 input_shape, + __private const int in_channel_block_length, + __private const int2 output_shape, + __private const int2 weights_shape, + __private const int2 stride_shape, + __private const int2 padding_shape, + __private const int2 dilation_shape, + __private const int out_width_blocks, + __private const int out_channel_blocks, + __private const int out_height_blocks + ,__private const int blockDim + ,__private const int inChannel +) { + + const int output_channel_width_idx = get_global_id(0); + const int output_batch_height_idx = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); + + const int out_channel_block_idx = output_channel_width_idx / out_width_blocks; + const int out_height_block_idx = output_channel_width_idx % out_width_blocks; + +#ifdef BIAS + FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_channel_block_idx, 0)); +#else + FLOAT4 out0 = (FLOAT4)0; +#endif + FLOAT4 out1 = out0; + FLOAT4 out2 = out0; + FLOAT4 out3 = out0; + + int in_width0 = mad24(out_height_block_idx, stride_shape.y<<2, -padding_shape.y); + int in_width1 = in_width0 + stride_shape.y; + int in_width2 = in_width0 + stride_shape.y * 2; + int in_width3 = in_width0 + stride_shape.y * 3; + +#ifdef MNN_CONV_S1D1 + const int height_start = mad24((output_batch_height_idx % output_shape.x), 1, -padding_shape.x); + const int kh_start = select(0, (-height_start), height_start < 0); + int in_height_start = kh_start + height_start; + int in_height_end = min(weights_shape.x + height_start, input_shape.x); + + const int batch_idx = mul24((output_batch_height_idx / output_shape.x), input_shape.x); + const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start), height_start < 0), weights_shape.y); +#else + const int height_start = mad24((output_batch_height_idx % output_shape.x), stride_shape.x, -padding_shape.x); + const int kh_start = select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0); + int in_height_start = mad24(kh_start, dilation_shape.x, height_start); + int in_height_end = min(mad24(weights_shape.x, dilation_shape.x, height_start), input_shape.x); + + const int batch_idx = mul24((output_batch_height_idx / output_shape.x), input_shape.x); + const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), weights_shape.y); +#endif + + const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; + + FLOAT4 in0, in1, in2, in3; + FLOAT4 weights0, weights1, weights2, weights3; + for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { + + int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); + COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); + + const int in_idx = mul24(in_channel_block_idx, input_shape.y); + int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + kh_start)*weights_shape.y + 0) * 4; + for (int iy = in_height_start; iy < in_height_end; iy += dilation_shape.x) { + int in_hb_value = iy + batch_idx; +#ifdef MNN_CONV_S1D1 + { + READ_INPUT_IMAGE(0, 0); + READ_INPUT_IMAGE(1, 0); + READ_INPUT_IMAGE(2, 0); + READ_INPUT_IMAGE(3, 0); + +#if QUANT_BIT == 8 + char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); + char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); + char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#else + uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); + uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); + uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); + uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); + char4 charWeight0 = (char4)(0, 0, 0, 0); + char4 charWeight1 = (char4)(0, 0, 0, 0); + char4 charWeight2 = (char4)(0, 0, 0, 0); + char4 charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + } + for (int w = 1; w < weights_shape.y; w++){ + in0 = in1; + in1 = in2; + in2 = in3; + READ_INPUT_IMAGE(3, w); +#if QUANT_BIT == 8 + char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); + char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); + char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#else + uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); + uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); + uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); + uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); + char4 charWeight0 = (char4)(0, 0, 0, 0); + char4 charWeight1 = (char4)(0, 0, 0, 0); + char4 charWeight2 = (char4)(0, 0, 0, 0); + char4 charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + } +#else + for (int w = 0; w < weights_shape.y; w++) { + int input_width_base = mul24(w, dilation_shape.y); + READ_INPUT_IMAGE(0, input_width_base); + READ_INPUT_IMAGE(1, input_width_base); + READ_INPUT_IMAGE(2, input_width_base); + READ_INPUT_IMAGE(3, input_width_base); +#if QUANT_BIT == 8 + char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); + char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); + char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#else + uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); + uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); + uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); + uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); + char4 charWeight0 = (char4)(0, 0, 0, 0); + char4 charWeight1 = (char4)(0, 0, 0, 0); + char4 charWeight2 = (char4)(0, 0, 0, 0); + char4 charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + } +#endif + } + } + +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); + out2 = fmax(out2, (FLOAT4)0); + out3 = fmax(out3, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); + out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); + out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); +#endif + + const int out_x_base = mul24(out_channel_block_idx, output_shape.y); + int out_x_idx = out_height_block_idx << 2; + + const int remain = output_shape.y - out_x_idx; + int output_idx = out_x_base + out_x_idx; + if (remain >= 4) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3); + } else if (remain == 3) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); + } else if (remain == 2) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); + } else if (remain == 1) { + WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); + } +} + +__kernel +#ifdef SET_ATTRIBUTE +__attribute__((work_group_size_hint(16, 16, 1))) +#endif +void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *kernel_ptr, + __global const float *dequantScaleOffset, +#else + __global const uchar *kernel_ptr, + __global const float *dequantScaleOffset, +#endif +#ifdef BIAS + __read_only image2d_t bias, +#endif + __write_only image2d_t output, + __private const int2 input_shape, + __private const int in_channel_block_length, + __private const int2 output_shape, + __private const int2 weights_shape, + __private const int2 stride_shape, + __private const int2 padding_shape, + __private const int2 dilation_shape, + __private const int out_width_blocks, + __private const int out_channel_blocks, + __private const int out_height_blocks + ,__private const int blockDim + ,__private const int inChannel +) { + + const int output_channel_width_idx = get_global_id(0); + const int output_batch_height_idx = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); + + const int out_channel_block_idx = (output_channel_width_idx / out_width_blocks) << 1; + const int out_width_block_idx = output_channel_width_idx % out_width_blocks; + const int out_height_block_idx = (output_batch_height_idx % out_height_blocks); + const int out_batch_block_idx = output_batch_height_idx / out_height_blocks; + +#ifdef BIAS + FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_channel_block_idx, 0)); + FLOAT4 out4 = RI_F(bias, SAMPLER, (int2)(out_channel_block_idx + 1, 0)); +#else + FLOAT4 out0 = (FLOAT4)0; + FLOAT4 out4 = (FLOAT4)0; +#endif + FLOAT4 out1 = out0; + FLOAT4 out2 = out0; + FLOAT4 out3 = out0; + FLOAT4 out5 = out4; + FLOAT4 out6 = out4; + FLOAT4 out7 = out4; + + const int weight_oc_offset = weights_shape.x * weights_shape.y * 4; + const int weight_ic_offset = out_channel_blocks * weight_oc_offset; + + int in_width0 = mad24(out_width_block_idx, stride_shape.y, -padding_shape.y); + int in_height0 = mad24(out_height_block_idx, stride_shape.x<<2, -padding_shape.x); + int in_height1 = in_height0 + stride_shape.x; + int in_height2 = in_height1 + stride_shape.x; + int in_height3 = in_height2 + stride_shape.x; + int weight_size = mul24(weights_shape.y, weights_shape.x); + + const int weights_h_idx = mul24(out_channel_block_idx, weight_size); + const int batch_idx = mul24(out_batch_block_idx, input_shape.x); + + FLOAT4 in0, in1, in2, in3; + FLOAT4 weights0, weights1, weights2, weights3, weights4, weights5, weights6, weights7; + for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { + int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); + COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); + COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx + 1, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); + COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); + + const int in_idx = mul24(in_channel_block_idx, input_shape.y); + int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; + + for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) { + int h0 = select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy >= input_shape.x)); + int h1 = select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy >= input_shape.x)); + int h2 = select(in_height2 + iy + batch_idx, -1, (in_height2 + iy < 0 || in_height2 + iy >= input_shape.x)); + int h3 = select(in_height3 + iy + batch_idx, -1, (in_height3 + iy < 0 || in_height3 + iy >= input_shape.x)); + for (int ix = 0; ix < weights_shape.y * dilation_shape.y; ix += dilation_shape.y) { + int w0 = select(in_width0 + ix + in_idx, -1, (in_width0 + ix < 0 || in_width0 + ix >= input_shape.y)); + + in0 = RI_F(input, SAMPLER, (int2)(w0, h0)); + in1 = RI_F(input, SAMPLER, (int2)(w0, h1)); + in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); + in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); + +#if QUANT_BIT == 8 + char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); + char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset); + char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*2); + char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_ic_offset*3); + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + #ifdef CHANNEL_BOUNDARY_PROTECT + charWeight0 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + charWeight1 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); + + #else + charWeight0 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); + #endif + weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1); + weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1); + weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1); + weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1); + weight_offset += 4; +#else + uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); + uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset/2); + uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*2/2); + uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_ic_offset*3/2); + char4 charWeight0 = (char4)(0, 0, 0, 0); + char4 charWeight1 = (char4)(0, 0, 0, 0); + char4 charWeight2 = (char4)(0, 0, 0, 0); + char4 charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); + charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2); + charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2); + charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2); + charWeight0 = (char4)(0, 0, 0, 0); + charWeight1 = (char4)(0, 0, 0, 0); + charWeight2 = (char4)(0, 0, 0, 0); + charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM)- 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM)- 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1); + weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1); + weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1); + weights7 = mad(CONVERT_FLOAT4(charWeight3), scale1, offset1); + weight_offset += 4; +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + PADZEROSVEC(in_channel_block_idx, inChannel, weights4, weights5, weights6, weights7); + + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + CALCULATE_OUTPUT_WEIGHTS4(4, 0); + CALCULATE_OUTPUT_WEIGHTS4(5, 1); + CALCULATE_OUTPUT_WEIGHTS4(6, 2); + CALCULATE_OUTPUT_WEIGHTS4(7, 3); + } + } + } + +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); + out2 = fmax(out2, (FLOAT4)0); + out3 = fmax(out3, (FLOAT4)0); + out4 = fmax(out4, (FLOAT4)0); + out5 = fmax(out5, (FLOAT4)0); + out6 = fmax(out6, (FLOAT4)0); + out7 = fmax(out7, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); + out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); + out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); + out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6); + out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6); + out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6); + out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6); +#endif + + const int out_x_base = mul24(out_channel_block_idx, output_shape.y); + const int out_y_base = mul24(out_batch_block_idx, output_shape.x); + int out_x_idx = out_width_block_idx; + int out_y_idx = out_height_block_idx << 2; + + const int remain_y = output_shape.x - out_y_idx; + int output_idx = out_x_base + out_x_idx; + int output_idy = out_y_base + out_y_idx; + + if(remain_y >= 4){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + WI_F(output, (int2)(output_idx, output_idy + 2), out2); + WI_F(output, (int2)(output_idx, output_idy + 3), out3); + }else if(remain_y == 3){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + WI_F(output, (int2)(output_idx, output_idy + 2), out2); + }else if(remain_y == 2){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + }else if(remain_y == 1){ + WI_F(output, (int2)(output_idx, output_idy), out0); + } + + if(out_channel_block_idx + 1 >= out_channel_blocks) { + return; + } + output_idx += output_shape.y; + if(remain_y >= 4){ + WI_F(output, (int2)(output_idx, output_idy), out4); + WI_F(output, (int2)(output_idx, output_idy + 1), out5); + WI_F(output, (int2)(output_idx, output_idy + 2), out6); + WI_F(output, (int2)(output_idx, output_idy + 3), out7); + }else if(remain_y == 3){ + WI_F(output, (int2)(output_idx, output_idy), out4); + WI_F(output, (int2)(output_idx, output_idy + 1), out5); + WI_F(output, (int2)(output_idx, output_idy + 2), out6); + }else if(remain_y == 2){ + WI_F(output, (int2)(output_idx, output_idy), out4); + WI_F(output, (int2)(output_idx, output_idy + 1), out5); + }else if(remain_y == 1){ + WI_F(output, (int2)(output_idx, output_idy), out4); + } +} + +__kernel +#ifdef SET_ATTRIBUTE +__attribute__((work_group_size_hint(16, 16, 1))) +#endif +void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *kernel_ptr, + __global const float *dequantScaleOffset, +#else + __global const uchar *kernel_ptr, + __global const float *dequantScaleOffset, +#endif +#ifdef BIAS + __read_only image2d_t bias, +#endif + __write_only image2d_t output, + __private const int2 input_shape, + __private const int in_channel_block_length, + __private const int2 output_shape, + __private const int2 weights_shape, + __private const int2 stride_shape, + __private const int2 padding_shape, + __private const int2 dilation_shape, + __private const int out_width_blocks, + __private const int out_channel_blocks, + __private const int out_height_blocks + ,__private const int blockDim + ,__private const int inChannel +) { + + const int output_channel_width_idx = get_global_id(0); + const int output_batch_height_idx = get_global_id(1); + DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); + + const int out_channel_block_idx = output_channel_width_idx / out_width_blocks; + const int out_width_block_idx = output_channel_width_idx % out_width_blocks; + const int out_height_block_idx = (output_batch_height_idx % out_height_blocks); + const int out_batch_block_idx = output_batch_height_idx / out_height_blocks; + +#ifdef BIAS + FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_channel_block_idx, 0)); +#else + FLOAT4 out0 = (FLOAT4)0; +#endif + FLOAT4 out1 = out0; + FLOAT4 out2 = out0; + FLOAT4 out3 = out0; + + int in_width0 = mad24(out_width_block_idx, stride_shape.y, -padding_shape.y); + int in_height0 = mad24(out_height_block_idx, stride_shape.x<<2, -padding_shape.x); + int in_height1 = in_height0 + stride_shape.x; + int in_height2 = in_height1 + stride_shape.x; + int in_height3 = in_height2 + stride_shape.x; + int weight_size = mul24(weights_shape.y, weights_shape.x); + + const int weights_h_idx = mul24(out_channel_block_idx, weight_size); + const int batch_idx = mul24(out_batch_block_idx, input_shape.x); + + FLOAT4 in0, in1, in2, in3; + FLOAT4 weights0, weights1, weights2, weights3; + const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; + for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { + int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); + COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); + const int in_idx = mul24(in_channel_block_idx, input_shape.y); + int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; + for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) { + int h0 = select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy >= input_shape.x)); + int h1 = select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy >= input_shape.x)); + int h2 = select(in_height2 + iy + batch_idx, -1, (in_height2 + iy < 0 || in_height2 + iy >= input_shape.x)); + int h3 = select(in_height3 + iy + batch_idx, -1, (in_height3 + iy < 0 || in_height3 + iy >= input_shape.x)); + for (int ix = 0; ix < weights_shape.y * dilation_shape.y; ix += dilation_shape.y) { + int w0 = select(in_width0 + ix + in_idx, -1, (in_width0 + ix < 0 || in_width0 + ix >= input_shape.y)); + + in0 = RI_F(input, SAMPLER, (int2)(w0, h0)); + in1 = RI_F(input, SAMPLER, (int2)(w0, h1)); + in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); + in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); + +#if QUANT_BIT == 8 + char4 charWeight0 = vload4(0, kernel_ptr+weight_offset); + char4 charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + char4 charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*2); + char4 charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset*3); + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#else + uchar2 charWeightInt40 = vload2(0, kernel_ptr+weight_offset/2); + uchar2 charWeightInt41 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset/2); + uchar2 charWeightInt42 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*2/2); + uchar2 charWeightInt43 = vload2(0, kernel_ptr+weight_offset/2+weight_oc_offset*3/2); + char4 charWeight0 = (char4)(0, 0, 0, 0); + char4 charWeight1 = (char4)(0, 0, 0, 0); + char4 charWeight2 = (char4)(0, 0, 0, 0); + char4 charWeight3 = (char4)(0, 0, 0, 0); + charWeight0.x = (charWeightInt40.s0 >> 4) - 8; + charWeight0.y = (charWeightInt40.s0 & MOD_NUM) - 8; + charWeight0.z = (charWeightInt40.s1 >> 4) - 8; + charWeight0.w = (charWeightInt40.s1 & MOD_NUM) - 8; + charWeight1.x = (charWeightInt41.s0 >> 4) - 8; + charWeight1.y = (charWeightInt41.s0 & MOD_NUM) - 8; + charWeight1.z = (charWeightInt41.s1 >> 4) - 8; + charWeight1.w = (charWeightInt41.s1 & MOD_NUM) - 8; + charWeight2.x = (charWeightInt42.s0 >> 4) - 8; + charWeight2.y = (charWeightInt42.s0 & MOD_NUM) - 8; + charWeight2.z = (charWeightInt42.s1 >> 4) - 8; + charWeight2.w = (charWeightInt42.s1 & MOD_NUM) - 8; + charWeight3.x = (charWeightInt43.s0 >> 4) - 8; + charWeight3.y = (charWeightInt43.s0 & MOD_NUM) - 8; + charWeight3.z = (charWeightInt43.s1 >> 4) - 8; + charWeight3.w = (charWeightInt43.s1 & MOD_NUM) - 8; + weights0 = mad(CONVERT_FLOAT4(charWeight0), scale0, offset0); + weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); + weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); + weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + weight_offset += 4; +#endif + PADZEROSVEC(in_channel_block_idx, inChannel, weights0, weights1, weights2, weights3); + + CALCULATE_OUTPUT(0); + CALCULATE_OUTPUT(1); + CALCULATE_OUTPUT(2); + CALCULATE_OUTPUT(3); + } + } + } + +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); + out2 = fmax(out2, (FLOAT4)0); + out3 = fmax(out3, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); + out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); + out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); +#endif + + const int out_x_base = mul24(out_channel_block_idx, output_shape.y); + const int out_y_base = mul24(out_batch_block_idx, output_shape.x); + int out_x_idx = out_width_block_idx; + int out_y_idx = out_height_block_idx << 2; + + const int remain_y = output_shape.x - out_y_idx; + int output_idx = out_x_base + out_x_idx; + int output_idy = out_y_base + out_y_idx; + + if(remain_y >= 4){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + WI_F(output, (int2)(output_idx, output_idy + 2), out2); + WI_F(output, (int2)(output_idx, output_idy + 3), out3); + }else if(remain_y == 3){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + WI_F(output, (int2)(output_idx, output_idy + 2), out2); + }else if(remain_y == 2){ + WI_F(output, (int2)(output_idx, output_idy), out0); + WI_F(output, (int2)(output_idx, output_idy + 1), out1); + }else{ + WI_F(output, (int2)(output_idx, output_idy), out0); + } +} diff --git a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl index 5d171dd2..28cbbfb0 100644 --- a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl @@ -23,7 +23,7 @@ __kernel void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 __global const char *weight, #else __global const uchar *weight, @@ -88,7 +88,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS const int filter_w_inc = (ix-in_w_idx_start)/dilate_hw.y; -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 char4 charWeight0 = vload4(filter_w_inc, weight+weight_offset); char4 charWeight1 = vload4(filter_w_inc, weight+weight_offset+weight_oc_offset); char4 charWeight2 = vload4(filter_w_inc, weight+weight_offset+weight_oc_offset*2); @@ -155,7 +155,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS __kernel void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 __global const char *weight, #else __global const uchar *weight, @@ -224,7 +224,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4((in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input+inp_offset_base)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4((in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input+inp_offset_base)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 char4 charWeight0 = vload4(0, weight+weight_offset); char4 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset); char4 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset*2); @@ -303,7 +303,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS __kernel void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 __global const char *weight, #else __global const uchar *weight, @@ -380,7 +380,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4((in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base)); COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4((in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 char4 charWeight0 = vload4(0, weight+weight_offset); char4 charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset); char4 charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset*2); @@ -482,7 +482,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS __kernel void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 __global const char *weight, #else __global const uchar *weight, @@ -579,7 +579,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4((in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base)); COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4((in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 char4 charWeight0 = vload4(0, weight+weight_offset); char4 charWeight1 = vload4(0, weight+weight_offset+weight_ic_offset); char4 charWeight2 = vload4(0, weight+weight_offset+weight_ic_offset*2); @@ -640,7 +640,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS out3 = mad(in3.z, weight2, out3); out3 = mad(in3.w, weight3, out3); -#if (defined USE_LOW_BIT_WEIGHT_INT8) +#if QUANT_BIT == 8 #ifdef CHANNEL_BOUNDARY_PROTECT charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset); charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); diff --git a/source/backend/opencl/execution/cl/conv_2d_int_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/conv_2d_int_buf_mnn_cl.cpp index 49330434..1a1b7f7e 100644 --- a/source/backend/opencl/execution/cl/conv_2d_int_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/conv_2d_int_buf_mnn_cl.cpp @@ -16,7 +16,7 @@ const char* conv_2d_int_buf = "__kernel\n" "void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " __global const char *weight,\n" "#else\n" " __global const uchar *weight,\n" @@ -77,7 +77,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n" " \n" " const int filter_w_inc=(ix-in_w_idx_start)/dilate_hw.y;\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " char4 charWeight0=vload4(filter_w_inc,weight+weight_offset);\n" " char4 charWeight1=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(filter_w_inc,weight+weight_offset+weight_oc_offset*2);\n" @@ -139,7 +139,7 @@ const char* conv_2d_int_buf = "__kernel\n" "void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " __global const char *weight,\n" "#else\n" " __global const uchar *weight,\n" @@ -202,7 +202,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 in0=CONVERT_COMPUTE_FLOAT4((in_w0_idx<0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx,input+inp_offset_base));\n" " COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n" " \n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " char4 charWeight0=vload4(0,weight+weight_offset);\n" " char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n" @@ -277,7 +277,7 @@ const char* conv_2d_int_buf = "__kernel\n" "void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " __global const char *weight,\n" "#else\n" " __global const uchar *weight,\n" @@ -345,7 +345,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n" " COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n" " COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " char4 charWeight0=vload4(0,weight+weight_offset);\n" " char4 charWeight1=vload4(0,weight+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,weight+weight_offset+weight_oc_offset*2);\n" @@ -442,7 +442,7 @@ const char* conv_2d_int_buf = "__kernel\n" "void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " __global const char *weight,\n" "#else\n" " __global const uchar *weight,\n" @@ -531,7 +531,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 in1=CONVERT_COMPUTE_FLOAT4((in_w1_idx<0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx,input+inp_offset_base));\n" " COMPUTE_FLOAT4 in2=CONVERT_COMPUTE_FLOAT4((in_w2_idx<0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx,input+inp_offset_base));\n" " COMPUTE_FLOAT4 in3=CONVERT_COMPUTE_FLOAT4((in_w3_idx<0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx,input+inp_offset_base));\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " char4 charWeight0=vload4(0,weight+weight_offset);\n" " char4 charWeight1=vload4(0,weight+weight_offset+weight_ic_offset);\n" " char4 charWeight2=vload4(0,weight+weight_offset+weight_ic_offset*2);\n" @@ -591,7 +591,7 @@ const char* conv_2d_int_buf = " out3=mad(in3.z,weight2,out3);\n" " out3=mad(in3.w,weight3,out3);\n" " \n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#if QUANT_BIT == 8\n" " #ifdef CHANNEL_BOUNDARY_PROTECT\n" " charWeight0=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset);\n" " charWeight1=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" diff --git a/source/backend/opencl/execution/cl/conv_2d_int_mnn_cl.cpp b/source/backend/opencl/execution/cl/conv_2d_int_mnn_cl.cpp new file mode 100644 index 00000000..354bdb6d --- /dev/null +++ b/source/backend/opencl/execution/cl/conv_2d_int_mnn_cl.cpp @@ -0,0 +1,1080 @@ +#include "opencl_source_map.hpp" +namespace MNN { +const char* conv_2d_int = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=RI_F(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n" +"#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i); \n" +"#define CALCULATE_OUTPUT_WEIGHTS4(i, j) "" out##i = mad(in##j.x, weights4, out##i); "" out##i = mad(in##j.y, weights5, out##i); "" out##i = mad(in##j.z, weights6, out##i); "" out##i=mad(in##j.w,weights7,out##i);\n" +"#define CALCULATE_OUTPUT_OPT(i) "" out##i = mad(in_sm##i[local_idx].x, weights0, out##i); "" out##i = mad(in_sm##i[local_idx].y, weights1, out##i); "" out##i = mad(in_sm##i[local_idx].z, weights2, out##i); "" out##i=mad(in_sm##i[local_idx].w,weights3,out##i); \n" +"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define UNIT 4\n" +"#define MOD_NUM 15\n" +"#ifdef INPUT_CHANNEL_LEAVE\n" +" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3= input_shape.y);\n" +" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n" +" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n" +" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n" +"#endif\n" +" int batch_index=output_batch_height_idx/output_shape.x;\n" +" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n" +" FLOAT4 in0;\n" +" FLOAT4 in1;\n" +" FLOAT4 in2;\n" +" FLOAT4 in3;\n" +" FLOAT4 weights0;\n" +" FLOAT4 weights1;\n" +" FLOAT4 weights2;\n" +" FLOAT4 weights3;\n" +" int weight_offset=output_channel_block_idx*in_channel_block*4*4;\n" +" for (int in_channel_block_idx=0; in_channel_block_idx> 4)-8;\n" +" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" +" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" +" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" +" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" +" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" +" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" +" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n" +" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" +" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" +" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" +" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" +" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" +" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" +" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" +" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n" +" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n" +" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n" +" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +" out2=fmax(out2,(FLOAT4)0);\n" +" out3=fmax(out3,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" +" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" const int out_x_base=mul24(output_channel_block_idx,output_shape.y);\n" +" int out_x_idx=output_width_block_idx << 2;\n" +" const int remain=output_shape.y-out_x_idx;\n" +" int output_idx=out_x_base+out_x_idx;\n" +" if (remain >= 4) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" +" } else if (remain == 3) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" } else if (remain == 2) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" } else if (remain == 1) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" }\n" +"}\n" +"__kernel\n" +"#ifdef SET_ATTRIBUTE\n" +"__attribute__((work_group_size_hint(16,16,1)))\n" +"#endif\n" +"void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" +"#if QUANT_BIT == 8\n" +" __global const char *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#else\n" +" __global const uchar *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#endif\n" +" __read_only image2d_t bias,\n" +" __write_only image2d_t output,\n" +" __private const int2 input_shape,\n" +" __private const int in_channel_block,__private const int2 output_shape,\n" +" __private const int2 stride_shape,\n" +" __private const int output_width_4,\n" +" __private const int out_channel_blocks\n" +" ,__private const int blockDim\n" +" ,__private const int inChannel\n" +") {\n" +" const int output_channel_width_idx=get_global_id(0);\n" +" const int output_batch_height_idx=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" +" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n" +" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n" +" const int output_channel_idx=output_channel_block_idx << 1;\n" +"#if QUANT_BIT == 4\n" +" int weight_ic_offset=output_channel_block_idx*16;\n" +" int weight_oc_offset=out_channel_blocks*8;\n" +"#else\n" +" int weight_ic_offset=output_channel_block_idx*32;\n" +" int weight_oc_offset=out_channel_blocks*16;\n" +"#endif\n" +" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n" +" FLOAT4 out1=out0;\n" +" FLOAT4 out2=out0;\n" +" FLOAT4 out3=out0;\n" +" \n" +" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(output_channel_idx+1,0));\n" +" FLOAT4 out5=out4;\n" +" FLOAT4 out6=out4;\n" +" FLOAT4 out7=out4;\n" +"#ifdef MNN_CONV_S1D1\n" +" int intput_width_idx0=output_width_block_idx << 2;\n" +" int intput_width_idx1=intput_width_idx0+1;\n" +" int intput_width_idx2=intput_width_idx0+2;\n" +" int intput_width_idx3=intput_width_idx0+3;\n" +"#else\n" +" int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n" +" int intput_width_idx1=intput_width_idx0+stride_shape.y;\n" +" int intput_width_idx2=intput_width_idx1+stride_shape.y;\n" +" int intput_width_idx3=intput_width_idx2+stride_shape.y;\n" +" intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n" +" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n" +" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n" +" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n" +"#endif\n" +" int batch_index=output_batch_height_idx/output_shape.x;\n" +" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n" +" FLOAT4 in0;\n" +" FLOAT4 in1;\n" +" FLOAT4 in2;\n" +" FLOAT4 in3;\n" +" FLOAT4 weights0;\n" +" FLOAT4 weights1;\n" +" FLOAT4 weights2;\n" +" FLOAT4 weights3;\n" +" FLOAT4 weights4;\n" +" FLOAT4 weights5;\n" +" FLOAT4 weights6;\n" +" FLOAT4 weights7;\n" +" int weight_offset=output_channel_idx*in_channel_block*4*4;\n" +" int weight_offset1=weight_offset+in_channel_block*4*4;\n" +" for (int in_channel_block_idx=0; in_channel_block_idx= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" +" #else\n" +" FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" +" #endif\n" +" FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n" +" FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n" +" FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n" +" FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n" +" FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n" +" FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n" +" FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n" +" FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n" +"#else\n" +" uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n" +" char4 charWeights0=(char4)(0,0,0,0);\n" +" char4 charWeights1=(char4)(0,0,0,0);\n" +" char4 charWeights2=(char4)(0,0,0,0);\n" +" char4 charWeights3=(char4)(0,0,0,0);\n" +" char4 charWeights4=(char4)(0,0,0,0);\n" +" char4 charWeights5=(char4)(0,0,0,0);\n" +" char4 charWeights6=(char4)(0,0,0,0);\n" +" char4 charWeights7=(char4)(0,0,0,0);\n" +" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n" +" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" +" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" +" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" +" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" +" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" +" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" +" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n" +" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" +" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" +" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" +" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" +" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" +" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" +" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" +" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" +" charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n" +" charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n" +" charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n" +" charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n" +" charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n" +" charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n" +" charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n" +" charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n" +" charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n" +" charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n" +" charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n" +" charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n" +" charWeights7.x=(charWeightsInt4.se >> 4)-8;\n" +" charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n" +" charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n" +" charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" +" weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n" +" weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n" +" weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n" +" weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" \n" +" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n" +" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n" +" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n" +" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +" out2=fmax(out2,(FLOAT4)0);\n" +" out3=fmax(out3,(FLOAT4)0);\n" +" out4=fmax(out4,(FLOAT4)0);\n" +" out5=fmax(out5,(FLOAT4)0);\n" +" out6=fmax(out6,(FLOAT4)0);\n" +" out7=fmax(out7,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" +" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" +" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n" +" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n" +" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n" +" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" const int out_x_base=mul24(output_channel_idx,output_shape.y);\n" +" int out_x_idx=output_width_block_idx << 2;\n" +" const int remain=output_shape.y-out_x_idx;\n" +" int output_idx=out_x_base+out_x_idx;\n" +" if (remain >= 4) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" +" } else if (remain == 3) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" } else if (remain == 2) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" } else if (remain == 1) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" }\n" +" \n" +" if(output_channel_idx+1 >= out_channel_blocks)\n" +" return;\n" +" output_idx += output_shape.y;\n" +" if (remain >= 4) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n" +" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out7);\n" +" } else if (remain == 3) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n" +" } else if (remain == 2) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" +" } else if (remain == 1) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" +" }\n" +"}\n" +"__kernel\n" +"#ifdef SET_ATTRIBUTE\n" +"__attribute__((work_group_size_hint(16,16,1)))\n" +"#endif\n" +"void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" +"#if QUANT_BIT == 8\n" +" __global const char *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#else\n" +" __global const uchar *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#endif\n" +"#ifdef BIAS\n" +" __read_only image2d_t bias,\n" +"#endif\n" +" __write_only image2d_t output,\n" +" __private const int2 input_shape,\n" +" __private const int in_channel_block_length,\n" +" __private const int2 output_shape,\n" +" __private const int2 weights_shape,\n" +" __private const int2 stride_shape,\n" +" __private const int2 padding_shape,\n" +" __private const int2 dilation_shape,\n" +" __private const int out_width_blocks,\n" +" __private const int out_channel_blocks,\n" +" __private const int out_height_blocks\n" +" ,__private const int blockDim\n" +" ,__private const int inChannel\n" +") {\n" +" const int output_channel_width_idx=get_global_id(0);\n" +" const int output_batch_height_idx=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" +" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n" +" const int out_height_block_idx=output_channel_width_idx % out_width_blocks;\n" +"#ifdef BIAS\n" +" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" +"#else\n" +" FLOAT4 out0=(FLOAT4)0;\n" +"#endif\n" +" FLOAT4 out1=out0;\n" +" FLOAT4 out2=out0;\n" +" FLOAT4 out3=out0;\n" +" int in_width0=mad24(out_height_block_idx,stride_shape.y<<2,-padding_shape.y);\n" +" int in_width1=in_width0+stride_shape.y;\n" +" int in_width2=in_width0+stride_shape.y*2;\n" +" int in_width3=in_width0+stride_shape.y*3;\n" +" \n" +"#ifdef MNN_CONV_S1D1\n" +" const int height_start=mad24((output_batch_height_idx % output_shape.x),1,-padding_shape.x);\n" +" const int kh_start=select(0,(-height_start),height_start<0);\n" +" int in_height_start=kh_start+height_start;\n" +" int in_height_end=min(weights_shape.x+height_start,input_shape.x);\n" +" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n" +" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start),height_start<0),weights_shape.y);\n" +"#else\n" +" const int height_start=mad24((output_batch_height_idx % output_shape.x),stride_shape.x,-padding_shape.x);\n" +" const int kh_start=select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0);\n" +" int in_height_start=mad24(kh_start,dilation_shape.x,height_start);\n" +" int in_height_end=min(mad24(weights_shape.x,dilation_shape.x,height_start),input_shape.x);\n" +" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n" +" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n" +"#endif\n" +" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" +" FLOAT4 in0,in1,in2,in3;\n" +" FLOAT4 weights0,weights1,weights2,weights3;\n" +" for (int in_channel_block_idx=0; in_channel_block_idx> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" weight_offset += 4;\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" }\n" +" for (int w=1; w> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" weight_offset += 4;\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" }\n" +"#else\n" +" for (int w=0; w> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" weight_offset += 4;\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" }\n" +"#endif\n" +" }\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +" out2=fmax(out2,(FLOAT4)0);\n" +" out3=fmax(out3,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" +" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" +" int out_x_idx=out_height_block_idx << 2;\n" +" const int remain=output_shape.y-out_x_idx;\n" +" int output_idx=out_x_base+out_x_idx;\n" +" if (remain >= 4) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" +" } else if (remain == 3) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" +" } else if (remain == 2) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" +" } else if (remain == 1) {\n" +" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" +" }\n" +"}\n" +"__kernel\n" +"#ifdef SET_ATTRIBUTE\n" +"__attribute__((work_group_size_hint(16,16,1)))\n" +"#endif\n" +"void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" +"#if QUANT_BIT == 8\n" +" __global const char *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#else\n" +" __global const uchar *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#endif\n" +"#ifdef BIAS\n" +" __read_only image2d_t bias,\n" +"#endif\n" +" __write_only image2d_t output,\n" +" __private const int2 input_shape,\n" +" __private const int in_channel_block_length,\n" +" __private const int2 output_shape,\n" +" __private const int2 weights_shape,\n" +" __private const int2 stride_shape,\n" +" __private const int2 padding_shape,\n" +" __private const int2 dilation_shape,\n" +" __private const int out_width_blocks,\n" +" __private const int out_channel_blocks,\n" +" __private const int out_height_blocks\n" +" ,__private const int blockDim\n" +" ,__private const int inChannel\n" +") {\n" +" const int output_channel_width_idx=get_global_id(0);\n" +" const int output_batch_height_idx=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" +" const int out_channel_block_idx=(output_channel_width_idx/out_width_blocks) << 1;\n" +" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n" +" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n" +" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n" +"#ifdef BIAS\n" +" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" +" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx+1,0));\n" +"#else\n" +" FLOAT4 out0=(FLOAT4)0;\n" +" FLOAT4 out4=(FLOAT4)0;\n" +"#endif\n" +" FLOAT4 out1=out0;\n" +" FLOAT4 out2=out0;\n" +" FLOAT4 out3=out0;\n" +" FLOAT4 out5=out4;\n" +" FLOAT4 out6=out4;\n" +" FLOAT4 out7=out4;\n" +" const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n" +" const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n" +" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n" +" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n" +" int in_height1=in_height0+stride_shape.x;\n" +" int in_height2=in_height1+stride_shape.x;\n" +" int in_height3=in_height2+stride_shape.x;\n" +" int weight_size=mul24(weights_shape.y,weights_shape.x);\n" +" \n" +" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n" +" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n" +" \n" +" FLOAT4 in0,in1,in2,in3;\n" +" FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n" +" for (int in_channel_block_idx=0; in_channel_block_idx= input_shape.x));\n" +" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n" +" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n" +" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n" +" for (int ix=0; ix= input_shape.y));\n" +" \n" +" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n" +" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n" +" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n" +" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n" +"#if QUANT_BIT == 8\n" +" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" +" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_ic_offset);\n" +" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*2);\n" +" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*3);\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" charWeight0=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" +" charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" \n" +" #else\n" +" charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" +" charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #endif\n" +" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" +" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" +" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" +" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" +" weight_offset += 4;\n" +"#else\n" +" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" +" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n" +" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n" +" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n" +" char4 charWeight0=(char4)(0,0,0,0);\n" +" char4 charWeight1=(char4)(0,0,0,0);\n" +" char4 charWeight2=(char4)(0,0,0,0);\n" +" char4 charWeight3=(char4)(0,0,0,0);\n" +" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" +" charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n" +" charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n" +" charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n" +" charWeight0=(char4)(0,0,0,0);\n" +" charWeight1=(char4)(0,0,0,0);\n" +" charWeight2=(char4)(0,0,0,0);\n" +" charWeight3=(char4)(0,0,0,0);\n" +" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" +" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" +" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" +" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" +"weight_offset += 4;\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" +" \n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n" +" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n" +" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n" +" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n" +" }\n" +" }\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +" out2=fmax(out2,(FLOAT4)0);\n" +" out3=fmax(out3,(FLOAT4)0);\n" +" out4=fmax(out4,(FLOAT4)0);\n" +" out5=fmax(out5,(FLOAT4)0);\n" +" out6=fmax(out6,(FLOAT4)0);\n" +" out7=fmax(out7,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" +" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" +" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n" +" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n" +" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n" +" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" +" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n" +" int out_x_idx=out_width_block_idx;\n" +" int out_y_idx=out_height_block_idx << 2;\n" +" const int remain_y=output_shape.x-out_y_idx;\n" +" int output_idx=out_x_base+out_x_idx;\n" +" int output_idy=out_y_base+out_y_idx;\n" +" \n" +" if(remain_y >= 4){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" +" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n" +" }else if(remain_y == 3){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" +" }else if(remain_y == 2){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" }else if(remain_y == 1){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" }\n" +" \n" +" if(out_channel_block_idx+1 >= out_channel_blocks) {\n" +" return;\n" +" }\n" +" output_idx += output_shape.y;\n" +" if(remain_y >= 4){\n" +" WI_F(output,(int2)(output_idx,output_idy),out4);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n" +" WI_F(output,(int2)(output_idx,output_idy+3),out7);\n" +" }else if(remain_y == 3){\n" +" WI_F(output,(int2)(output_idx,output_idy),out4);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n" +" }else if(remain_y == 2){\n" +" WI_F(output,(int2)(output_idx,output_idy),out4);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" +" }else if(remain_y == 1){\n" +" WI_F(output,(int2)(output_idx,output_idy),out4);\n" +" }\n" +"}\n" +"__kernel\n" +"#ifdef SET_ATTRIBUTE\n" +"__attribute__((work_group_size_hint(16,16,1)))\n" +"#endif\n" +"void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" +"#if QUANT_BIT == 8\n" +" __global const char *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#else\n" +" __global const uchar *kernel_ptr,\n" +" __global const float *dequantScaleOffset,\n" +"#endif\n" +"#ifdef BIAS\n" +" __read_only image2d_t bias,\n" +"#endif\n" +" __write_only image2d_t output,\n" +" __private const int2 input_shape,\n" +" __private const int in_channel_block_length,\n" +" __private const int2 output_shape,\n" +" __private const int2 weights_shape,\n" +" __private const int2 stride_shape,\n" +" __private const int2 padding_shape,\n" +" __private const int2 dilation_shape,\n" +" __private const int out_width_blocks,\n" +" __private const int out_channel_blocks,\n" +" __private const int out_height_blocks\n" +" ,__private const int blockDim\n" +" ,__private const int inChannel\n" +") {\n" +" const int output_channel_width_idx=get_global_id(0);\n" +" const int output_batch_height_idx=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" +" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n" +" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n" +" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n" +" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n" +"#ifdef BIAS\n" +" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" +"#else\n" +" FLOAT4 out0=(FLOAT4)0;\n" +"#endif\n" +" FLOAT4 out1=out0;\n" +" FLOAT4 out2=out0;\n" +" FLOAT4 out3=out0;\n" +" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n" +" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n" +" int in_height1=in_height0+stride_shape.x;\n" +" int in_height2=in_height1+stride_shape.x;\n" +" int in_height3=in_height2+stride_shape.x;\n" +" int weight_size=mul24(weights_shape.y,weights_shape.x);\n" +" \n" +" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n" +" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n" +" \n" +" FLOAT4 in0,in1,in2,in3;\n" +" FLOAT4 weights0,weights1,weights2,weights3;\n" +" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" +" for (int in_channel_block_idx=0; in_channel_block_idx= input_shape.x));\n" +" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n" +" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n" +" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n" +" for (int ix=0; ix= input_shape.y));\n" +" \n" +" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n" +" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n" +" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n" +" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n" +" \n" +"#if QUANT_BIT == 8\n" +" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" +" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" +" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" +" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" weight_offset += 4;\n" +"#else\n" +" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" +" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" +" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" +" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" +" char4 charWeight0=(char4)(0,0,0,0);\n" +" char4 charWeight1=(char4)(0,0,0,0);\n" +" char4 charWeight2=(char4)(0,0,0,0);\n" +" char4 charWeight3=(char4)(0,0,0,0);\n" +" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" +" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" +" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" +" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" +" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" +" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" +" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" +" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" +" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" +" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" +" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" +" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" +" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" +" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" +" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" +" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" +" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" +" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" +" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" +" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" weight_offset += 4;\n" +"#endif\n" +" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" +" CALCULATE_OUTPUT(0);\n" +" CALCULATE_OUTPUT(1);\n" +" CALCULATE_OUTPUT(2);\n" +" CALCULATE_OUTPUT(3);\n" +" }\n" +" }\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +" out2=fmax(out2,(FLOAT4)0);\n" +" out3=fmax(out3,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" +" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" +" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n" +" int out_x_idx=out_width_block_idx;\n" +" int out_y_idx=out_height_block_idx << 2;\n" +" const int remain_y=output_shape.x-out_y_idx;\n" +" int output_idx=out_x_base+out_x_idx;\n" +" int output_idy=out_y_base+out_y_idx;\n" +" if(remain_y >= 4){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" +" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n" +" }else if(remain_y == 3){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" +" }else if(remain_y == 2){\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" +" }else{\n" +" WI_F(output,(int2)(output_idx,output_idy),out0);\n" +" }\n" +"}\n" +; +} diff --git a/source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp b/source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp index 778e64fc..ad9117b2 100644 --- a/source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp @@ -13,15 +13,8 @@ const char* conv_2d = "#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" "#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"#define UNIT 4\n" -"#define MOD_NUM 15\n" -"#ifdef INPUT_CHANNEL_LEAVE\n" -" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3> 4)-8;\n" -" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" -" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" -" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" -" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" -" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" -" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" -" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n" -" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" -" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" -" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" -" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" -" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" -" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" -" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" -" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(weights_width_base,weights+weight_offset);\n" " weights1=vload4(weights_width_base+1,weights+weight_offset);\n" " weights2=vload4(weights_width_base+2,weights+weight_offset);\n" @@ -253,7 +192,6 @@ const char* conv_2d = " weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_block_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_block_idx));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n" " in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n" " in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n" @@ -296,17 +234,11 @@ const char* conv_2d = " }\n" "}\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" @@ -318,10 +250,6 @@ const char* conv_2d = " __private const int2 stride_shape,\n" " __private const int output_width_4,\n" " __private const int out_channel_blocks\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" ,__private const int blockDim\n" -" ,__private const int inChannel\n" -"#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" @@ -329,13 +257,8 @@ const char* conv_2d = " const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n" " const int output_width_block_idx=output_channel_width_idx % output_width_4;\n" " const int output_channel_idx=output_channel_block_idx << 1;\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" int weight_ic_offset=output_channel_block_idx*16;\n" -" int weight_oc_offset=out_channel_blocks*8;\n" -"#else\n" " int weight_ic_offset=output_channel_block_idx*32;\n" " int weight_oc_offset=out_channel_blocks*16;\n" -"#endif\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" @@ -377,16 +300,6 @@ const char* conv_2d = " int weight_offset=output_channel_idx*in_channel_block*4*4;\n" " int weight_offset1=weight_offset+in_channel_block*4*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" -" #else\n" -" FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" -" #endif\n" -" FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n" -" FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n" -" FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n" -" FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n" -" FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n" -" FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n" -" FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n" -" FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n" -" char4 charWeights0=(char4)(0,0,0,0);\n" -" char4 charWeights1=(char4)(0,0,0,0);\n" -" char4 charWeights2=(char4)(0,0,0,0);\n" -" char4 charWeights3=(char4)(0,0,0,0);\n" -" char4 charWeights4=(char4)(0,0,0,0);\n" -" char4 charWeights5=(char4)(0,0,0,0);\n" -" char4 charWeights6=(char4)(0,0,0,0);\n" -" char4 charWeights7=(char4)(0,0,0,0);\n" -" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n" -" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" -" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" -" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" -" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" -" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" -" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" -" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n" -" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" -" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" -" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" -" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" -" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" -" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" -" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" -" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" -" charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n" -" charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n" -" charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n" -" charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n" -" charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n" -" charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n" -" charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n" -" charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n" -" charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n" -" charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n" -" charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n" -" charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n" -" charWeights7.x=(charWeightsInt4.se >> 4)-8;\n" -" charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n" -" charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n" -" charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" -" weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n" -" weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n" -" weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n" -" weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(weights_width_base,weights+weight_offset);\n" " weights1=vload4(weights_width_base+1,weights+weight_offset);\n" " weights2=vload4(weights_width_base+2,weights+weight_offset);\n" @@ -486,8 +334,6 @@ const char* conv_2d = " weights6=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx+1));\n" " weights7=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx+1));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" @@ -558,17 +404,11 @@ const char* conv_2d = " }\n" "}\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" @@ -587,10 +427,6 @@ const char* conv_2d = " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" ,__private const int blockDim\n" -" ,__private const int inChannel\n" -"#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" @@ -625,22 +461,14 @@ const char* conv_2d = " const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n" " const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n" "#endif\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" "#endif\n" " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3;\n" " for (int in_channel_block_idx=0; in_channel_block_idx> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" @@ -707,7 +495,6 @@ const char* conv_2d = " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" @@ -718,47 +505,7 @@ const char* conv_2d = " in1=in2;\n" " in2=in3;\n" " READ_INPUT_IMAGE(3,w);\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" -" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" -" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" -" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" -" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" -" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" -" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" -" char4 charWeight0=(char4)(0,0,0,0);\n" -" char4 charWeight1=(char4)(0,0,0,0);\n" -" char4 charWeight2=(char4)(0,0,0,0);\n" -" char4 charWeight3=(char4)(0,0,0,0);\n" -" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" @@ -770,7 +517,6 @@ const char* conv_2d = " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" @@ -783,47 +529,7 @@ const char* conv_2d = " READ_INPUT_IMAGE(1,input_width_base);\n" " READ_INPUT_IMAGE(2,input_width_base);\n" " READ_INPUT_IMAGE(3,input_width_base);\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" -" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" -" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" -" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" -" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" -" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" -" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" -" char4 charWeight0=(char4)(0,0,0,0);\n" -" char4 charWeight1=(char4)(0,0,0,0);\n" -" char4 charWeight2=(char4)(0,0,0,0);\n" -" char4 charWeight3=(char4)(0,0,0,0);\n" -" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" @@ -835,7 +541,6 @@ const char* conv_2d = " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx)); \n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" @@ -877,17 +582,11 @@ const char* conv_2d = " }\n" "}\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" @@ -906,10 +605,6 @@ const char* conv_2d = " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" ,__private const int blockDim\n" -" ,__private const int inChannel\n" -"#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" @@ -931,7 +626,7 @@ const char* conv_2d = " FLOAT4 out5=out4;\n" " FLOAT4 out6=out4;\n" " FLOAT4 out7=out4;\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n" " const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n" "#endif\n" @@ -948,18 +643,8 @@ const char* conv_2d = " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" -" charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" -" charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" -" charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" -" \n" -" #else\n" -" charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" -" charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" -" charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" -" charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" -" #endif\n" -" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" -" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" -" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" -" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" -" weight_offset += 4;\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" -" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n" -" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n" -" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n" -" char4 charWeight0=(char4)(0,0,0,0);\n" -" char4 charWeight1=(char4)(0,0,0,0);\n" -" char4 charWeight2=(char4)(0,0,0,0);\n" -" char4 charWeight3=(char4)(0,0,0,0);\n" -" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" -" charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n" -" charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n" -" charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n" -" charWeight0=(char4)(0,0,0,0);\n" -" charWeight1=(char4)(0,0,0,0);\n" -" charWeight2=(char4)(0,0,0,0);\n" -" charWeight3=(char4)(0,0,0,0);\n" -" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" -" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" -" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" -" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" -" weight_offset += 4;\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n" @@ -1088,8 +689,6 @@ const char* conv_2d = " weights6=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weight_size+weights_y_idx));\n" " weights7=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weight_size+weights_y_idx++));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" " \n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" @@ -1167,17 +766,11 @@ const char* conv_2d = " }\n" "}\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *kernel_ptr,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" @@ -1196,10 +789,6 @@ const char* conv_2d = " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" ,__private const int blockDim\n" -" ,__private const int inChannel\n" -"#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" @@ -1228,18 +817,12 @@ const char* conv_2d = " \n" " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3;\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" "#endif\n" " for (int in_channel_block_idx=0; in_channel_block_idx> 4)-8;\n" -" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" -" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" -" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" -" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" -" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" -" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" -" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" -" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" -" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" -" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" -" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" -" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" -" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" -" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" -" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" -" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" -" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" -" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" -" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" -" weight_offset += 4;\n" -"#elif (defined USE_BUFFER)\n" +"#ifdef USE_BUFFER\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" @@ -1310,7 +853,6 @@ const char* conv_2d = " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" -" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" diff --git a/source/backend/opencl/execution/cl/depthwise_conv2d.cl b/source/backend/opencl/execution/cl/depthwise_conv2d.cl index 506aa7de..3d87d512 100644 --- a/source/backend/opencl/execution/cl/depthwise_conv2d.cl +++ b/source/backend/opencl/execution/cl/depthwise_conv2d.cl @@ -23,7 +23,7 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter, @@ -130,7 +130,7 @@ void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_ } __kernel -#if SET_ATTRIBUTE +#ifdef SET_ATTRIBUTE __attribute__((work_group_size_hint(16, 16, 1))) #endif void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter, diff --git a/source/backend/opencl/execution/cl/depthwise_conv2d_mnn_cl.cpp b/source/backend/opencl/execution/cl/depthwise_conv2d_mnn_cl.cpp index 3f5cb3b1..c74cbb71 100644 --- a/source/backend/opencl/execution/cl/depthwise_conv2d_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/depthwise_conv2d_mnn_cl.cpp @@ -10,7 +10,7 @@ const char* depthwise_conv2d = "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" "#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n" @@ -102,7 +102,7 @@ const char* depthwise_conv2d = " }\n" "}\n" "__kernel\n" -"#if SET_ATTRIBUTE\n" +"#ifdef SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,__read_only image2d_t filter,\n" diff --git a/source/backend/opencl/execution/cl/gemm.cl b/source/backend/opencl/execution/cl/gemm.cl index eb2d9e19..78a2cea4 100644 --- a/source/backend/opencl/execution/cl/gemm.cl +++ b/source/backend/opencl/execution/cl/gemm.cl @@ -289,82 +289,24 @@ __kernel void gemmWinogradW2(__read_only image2d_t uInput, __read_only image2d_t __kernel void gemm_conv(GLOBAL_SIZE_DIM2 __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *weight, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *weight, - __global const float *dequantScaleOffset, -#else __global const FLOAT *weight, -#endif __read_only image2d_t bias, __write_only image2d_t output, __private const int dstChannelC4, __private const int srcChannelC4, __private const int batch -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int srcChannel -#endif ) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); FLOAT4 out = RI_F(bias, SAMPLER, (int2)(pos.x, 0)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) int weight_offset = pos.x * 16; int weight_oc_offset = dstChannelC4 * 16; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = pos.x * 8; - int weight_oc_offset = dstChannelC4 * 8; -#else - int weight_offset = pos.x * 16; - int weight_oc_offset = dstChannelC4 * 16; -#endif for (int k = 0; k < srcChannelC4; ++k) { -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (k * 4) / blockDim * dstChannelC4 * 8; - COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex)); - COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6); - COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7); -#endif FLOAT4 in = RI_F(input, SAMPLER, (int2)(k, pos.y)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8; - charWeights.s1 = (charWeightsInt4.s0 & 15) - 8; - charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8; - charWeights.s3 = (charWeightsInt4.s1 & 15) - 8; - charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8; - charWeights.s5 = (charWeightsInt4.s2 & 15) - 8; - charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8; - charWeights.s7 = (charWeightsInt4.s3 & 15) - 8; - charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8; - charWeights.s9 = (charWeightsInt4.s4 & 15) - 8; - charWeights.sa = (charWeightsInt4.s5 >> 4) - 8; - charWeights.sb = (charWeightsInt4.s5 & 15) - 8; - charWeights.sc = (charWeightsInt4.s6 >> 4) - 8; - charWeights.sd = (charWeightsInt4.s6 & 15) - 8; - charWeights.se = (charWeightsInt4.s7 >> 4) - 8; - charWeights.sf = (charWeightsInt4.s7 & 15) - 8; - FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset; - -#else FLOAT16 weights = vload16(0, weight + weight_offset + k * weight_oc_offset); -#endif - PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef); out = mad((FLOAT4)in.x, (FLOAT4)weights.s0123, out); out = mad((FLOAT4)in.y, (FLOAT4)weights.s4567, out); @@ -385,24 +327,12 @@ __kernel void gemm_conv(GLOBAL_SIZE_DIM2 __kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2 __read_only image2d_t input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *weight, - __global const float *dequantScaleOffset, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *weight, - __global const float *dequantScaleOffset, -#else __global const FLOAT *weight, -#endif __read_only image2d_t bias, __write_only image2d_t output, __private const int dstChannelC4, __private const int srcChannelC4, __private const int batch -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - ,__private const int blockDim - ,__private const int srcChannel -#endif ) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -412,58 +342,13 @@ __kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2 FLOAT4 bias0 = RI_F(bias, SAMPLER, (int2)(pos.x, 0)); FLOAT4 out0 = bias0, out1 = bias0; -#if (defined USE_LOW_BIT_WEIGHT_INT8) int weight_offset = pos.x * 16; int weight_oc_offset = dstChannelC4 * 16; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = pos.x * 8; - int weight_oc_offset = dstChannelC4 * 8; -#else - int weight_offset = pos.x * 16; - int weight_oc_offset = dstChannelC4 * 16; -#endif for (int k = 0; k < srcChannelC4; ++k) { -#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) - int kindex = (k * 4) / blockDim * dstChannelC4 * 8; - COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex)); - COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, - ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6); - COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, - ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7); -#endif FLOAT4 in0 = RI_F(input, SAMPLER, (int2)(k, pos_y)); FLOAT4 in1 = RI_F(input, SAMPLER, (int2)(k, pos_y + 1)); -#if (defined USE_LOW_BIT_WEIGHT_INT8) - FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8; - charWeights.s1 = (charWeightsInt4.s0 & 15) - 8; - charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8; - charWeights.s3 = (charWeightsInt4.s1 & 15) - 8; - charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8; - charWeights.s5 = (charWeightsInt4.s2 & 15) - 8; - charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8; - charWeights.s7 = (charWeightsInt4.s3 & 15) - 8; - charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8; - charWeights.s9 = (charWeightsInt4.s4 & 15) - 8; - charWeights.sa = (charWeightsInt4.s5 >> 4) - 8; - charWeights.sb = (charWeightsInt4.s5 & 15) - 8; - charWeights.sc = (charWeightsInt4.s6 >> 4) - 8; - charWeights.sd = (charWeightsInt4.s6 & 15) - 8; - charWeights.se = (charWeightsInt4.s7 >> 4) - 8; - charWeights.sf = (charWeightsInt4.s7 & 15) - 8; - FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset; -#else FLOAT16 weights = vload16(0, weight + weight_offset + k * weight_oc_offset); -#endif - PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef); out0 = mad((FLOAT4)in0.x, (FLOAT4)weights.s0123, out0); out0 = mad((FLOAT4)in0.y, (FLOAT4)weights.s4567, out0); diff --git a/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl b/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl index 72b25eb1..b72ccf55 100644 --- a/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl +++ b/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl @@ -18,9 +18,9 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2 #ifdef USE_IMAGE __read_only image2d_t weight, #else - #if (defined USE_LOW_BIT_WEIGHT_INT8) + #if QUANT_BIT == 8 __global const char *weight, - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #else __global const uchar *weight, #endif #endif @@ -37,7 +37,7 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2 UNIFORM_BOUNDRY_CHECK(x, y); -#if (defined USE_LOW_BIT_WEIGHT_INT4) +#if QUANT_BIT == 4 const int ic = x << 2; const int oc = y << 3; const int output_offset = ic * outputChannelAlign + oc; @@ -90,7 +90,7 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2 vstore8(CONVERT_FLOAT8(weights1), 0, output+output_offset+outputChannelAlign); vstore8(CONVERT_FLOAT8(weights2), 0, output+output_offset+2*outputChannelAlign); vstore8(CONVERT_FLOAT8(weights3), 0, output+output_offset+3*outputChannelAlign); -#else +#elif QUANT_BIT == 8 const int ic = x << 1; const int oc = y << 3; const int output_offset = ic * outputChannelAlign + oc; @@ -125,6 +125,83 @@ __kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2 #endif } +__kernel void gemm_c4nhw4_to_nhwc(GLOBAL_SIZE_DIM2 +__global const FLOAT* input, +__global FLOAT* output, +__private const int bhw, +__private const int channel, +__private const int channelAlign +){ + const int x = get_global_id(0); //b/4 + const int y = get_global_id(1); //c/4 + UNIFORM_BOUNDRY_CHECK(x, y); + const int out_b_idx = x << 2; + const int out_c_idx = y << 2; + const int bhw4 = bhw << 2; + const int input_offset = y * bhw4 + out_b_idx * 4; + FLOAT4 in0, in1, in2, in3; + if(out_c_idx + 3 < channel && out_b_idx + 3 < bhw){ + in0 = vload4(0, input + input_offset); + in1 = vload4(0, input + input_offset + 4); + in2 = vload4(0, input + input_offset + 8); + in3 = vload4(0, input + input_offset + 12); + } else{ + if(out_c_idx + 3 < channel){ + in0 = vload4(0, input + input_offset); + in1 = out_b_idx + 1 < bhw ? vload4(0, input + input_offset + 4) : 0; + in2 = out_b_idx + 2 < bhw ? vload4(0, input + input_offset + 8) : 0; + in3 = out_b_idx + 3 < bhw ? vload4(0, input + input_offset + 12) : 0; + } else if(out_c_idx + 1 == channel){ + in0 = (FLOAT4)(input[input_offset], 0, 0, 0); + in1 = out_b_idx + 1 < bhw ? (FLOAT4)(input[input_offset + 4], 0, 0, 0) : 0; + in2 = out_b_idx + 2 < bhw ? (FLOAT4)(input[input_offset + 8], 0, 0, 0) : 0; + in3 = out_b_idx + 3 < bhw ? (FLOAT4)(input[input_offset + 12], 0, 0, 0) : 0; + } else if(out_c_idx + 2 == channel){ + in0 = (FLOAT4)(input[input_offset], input[input_offset + 1], 0, 0); + in1 = out_b_idx + 1 < bhw ? (FLOAT4)(input[input_offset + 4], input[input_offset + 5], 0, 0) : 0; + in2 = out_b_idx + 2 < bhw ? (FLOAT4)(input[input_offset + 8], input[input_offset + 9], 0, 0) : 0; + in3 = out_b_idx + 3 < bhw ? (FLOAT4)(input[input_offset + 12], input[input_offset + 13], 0, 0) : 0; + } else if(out_c_idx + 3 == channel){ + in0 = (FLOAT4)(input[input_offset], input[input_offset + 1], input[input_offset + 2], 0); + in1 = out_b_idx + 1 < bhw ? (FLOAT4)(input[input_offset + 4], input[input_offset + 5], input[input_offset + 6], 0) : 0; + in2 = out_b_idx + 2 < bhw ? (FLOAT4)(input[input_offset + 8], input[input_offset + 9], input[input_offset + 10], 0) : 0; + in3 = out_b_idx + 3 < bhw ? (FLOAT4)(input[input_offset + 12], input[input_offset + 13], input[input_offset + 14], 0) : 0; + } + } + int out_offset = out_b_idx * channelAlign + out_c_idx; + vstore4(in0, 0, output + out_offset); + vstore4(in1, 0, output + out_offset + channelAlign); + vstore4(in2, 0, output + out_offset + channelAlign + channelAlign); + vstore4(in3, 0, output + out_offset + channelAlign + channelAlign + channelAlign); +} + +__kernel void gemm_nhwc_to_c4nhw4(GLOBAL_SIZE_DIM2 +__global const FLOAT* input, +__global FLOAT* output, +__private const int bhw, +__private const int channelAlign +){ + const int x = get_global_id(0); //b/4 + const int y = get_global_id(1); //c/4 + UNIFORM_BOUNDRY_CHECK(x, y); + const int out_b_idx = x << 2; + const int out_c_idx = y << 2; + const int bhw4 = bhw << 2; + const int input_offset = out_b_idx * channelAlign + out_c_idx; + FLOAT4 in0 = vload4(0, input + input_offset); + FLOAT4 in1 = vload4(0, input + input_offset + channelAlign); + FLOAT4 in2 = vload4(0, input + input_offset + channelAlign + channelAlign); + FLOAT4 in3 = vload4(0, input + input_offset + channelAlign + channelAlign + channelAlign); + int out_offset = y * bhw4 + out_b_idx * 4; + vstore4(in0, 0, output + out_offset); + if(out_b_idx + 1 >= bhw) return; + vstore4(in1, 0, output + out_offset + 4); + if(out_b_idx + 2 >= bhw) return; + vstore4(in2, 0, output + out_offset + 8); + if(out_b_idx + 3 >= bhw) return; + vstore4(in3, 0, output + out_offset + 12); +} + #define UCHAR4_TO_FLOAT8(b, scale, offset) \ wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); \ wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); \ @@ -440,7 +517,6 @@ __kernel void gemm_b4_c8_int4_buf(GLOBAL_SIZE_DIM2 #endif } - __kernel void gemm_b4_c8_int8_buf(GLOBAL_SIZE_DIM2 __global const FLOAT* input, #ifdef USE_IMAGE diff --git a/source/backend/opencl/execution/cl/gemm_conv1x1_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/gemm_conv1x1_buf_mnn_cl.cpp index 5d5702c5..e1fed490 100644 --- a/source/backend/opencl/execution/cl/gemm_conv1x1_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/gemm_conv1x1_buf_mnn_cl.cpp @@ -13,9 +13,9 @@ const char* gemm_conv1x1_buf = " #ifdef USE_IMAGE\n" " __read_only image2d_t weight,\n" " #else\n" -" #if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +" #if QUANT_BIT == 8\n" " __global const char *weight,\n" -" #elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +" #else\n" " __global const uchar *weight,\n" " #endif\n" " #endif\n" @@ -31,7 +31,7 @@ const char* gemm_conv1x1_buf = " const int y=get_global_id(1); //oc\n" " UNIFORM_BOUNDRY_CHECK(x,y);\n" " \n" -"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n" +"#if QUANT_BIT == 4\n" " const int ic=x << 2;\n" " const int oc=y << 3;\n" " const int output_offset=ic*outputChannelAlign+oc;\n" @@ -83,7 +83,7 @@ const char* gemm_conv1x1_buf = " vstore8(CONVERT_FLOAT8(weights1),0,output+output_offset+outputChannelAlign);\n" " vstore8(CONVERT_FLOAT8(weights2),0,output+output_offset+2*outputChannelAlign);\n" " vstore8(CONVERT_FLOAT8(weights3),0,output+output_offset+3*outputChannelAlign);\n" -"#else\n" +"#elif QUANT_BIT == 8\n" " const int ic=x << 1;\n" " const int oc=y << 3;\n" " const int output_offset=ic*outputChannelAlign+oc;\n" @@ -117,6 +117,81 @@ const char* gemm_conv1x1_buf = " vstore8(CONVERT_FLOAT8(weights1),0,output+output_offset+outputChannelAlign);\n" " #endif\n" "}\n" +"__kernel void gemm_c4nhw4_to_nhwc(GLOBAL_SIZE_DIM2\n" +"__global const FLOAT* input,\n" +"__global FLOAT* output,\n" +"__private const int bhw,\n" +"__private const int channel,\n" +"__private const int channelAlign\n" +"){\n" +" const int x=get_global_id(0); //b/4\n" +" const int y=get_global_id(1); //c/4\n" +" UNIFORM_BOUNDRY_CHECK(x,y);\n" +" const int out_b_idx=x << 2;\n" +" const int out_c_idx=y << 2;\n" +" const int bhw4=bhw << 2;\n" +" const int input_offset=y*bhw4+out_b_idx*4;\n" +" FLOAT4 in0,in1,in2,in3;\n" +" if(out_c_idx+3= bhw) return;\n" +" vstore4(in1,0,output+out_offset+4);\n" +" if(out_b_idx+2 >= bhw) return;\n" +" vstore4(in2,0,output+out_offset+8);\n" +" if(out_b_idx+3 >= bhw) return;\n" +" vstore4(in3,0,output+out_offset+12);\n" +"}\n" "#define UCHAR4_TO_FLOAT8(b, scale, offset) "" wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); "" wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); "" wei.s2 = (COMPUTE_FLOAT)((b.s1 >> 4) - 8); "" wei.s3 = (COMPUTE_FLOAT)((b.s1 & 15) - 8); "" wei.s4 = (COMPUTE_FLOAT)((b.s2 >> 4) - 8); "" wei.s5 = (COMPUTE_FLOAT)((b.s2 & 15) - 8); "" wei.s6 = (COMPUTE_FLOAT)((b.s3 >> 4) - 8); "" wei.s7 = (COMPUTE_FLOAT)((b.s3 & 15) - 8); "" wei=wei*scale+offset;\n" "__kernel void gemm_b4_c8_int4_buf(GLOBAL_SIZE_DIM2\n" " __global const FLOAT* input,\n" diff --git a/source/backend/opencl/execution/cl/gemm_int.cl b/source/backend/opencl/execution/cl/gemm_int.cl new file mode 100644 index 00000000..2b563819 --- /dev/null +++ b/source/backend/opencl/execution/cl/gemm_int.cl @@ -0,0 +1,203 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + +#define GLOBAL_SIZE_DIM2 \ + __private int global_size_dim0, __private int global_size_dim1, + +#define UNIFORM_BOUNDRY_CHECK(index0, index1) \ + if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \ + return; \ + } + +__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +#ifdef INPUT_CHANNEL_LEAVE + #define PADZEROSVEC(k, channel, data0, data1, data2, data3) \ + data0 = (k << 2) < channel ? data0 : 0; \ + data1 = (k << 2) + 1 < channel ? data1 : 0; \ + data2 = (k << 2) + 2 < channel ? data2 : 0; \ + data3 = (k << 2) + 3 < channel ? data3 : 0; +#else + #define PADZEROSVEC(k, channel, data0, data1, data2, data3) +#endif + +__kernel void gemm_conv(GLOBAL_SIZE_DIM2 + __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *weight, + __global const float *dequantScaleOffset, +#else + __global const uchar *weight, + __global const float *dequantScaleOffset, +#endif + __read_only image2d_t bias, + __write_only image2d_t output, + __private const int dstChannelC4, + __private const int srcChannelC4, + __private const int batch + ,__private const int blockDim + ,__private const int srcChannel +) { + int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b + UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); + + FLOAT4 out = RI_F(bias, SAMPLER, (int2)(pos.x, 0)); + +#if QUANT_BIT == 8 + int weight_offset = pos.x * 16; + int weight_oc_offset = dstChannelC4 * 16; +#else + int weight_offset = pos.x * 8; + int weight_oc_offset = dstChannelC4 * 8; +#endif + + for (int k = 0; k < srcChannelC4; ++k) { + int kindex = (k * 4) / blockDim * dstChannelC4 * 8; + COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex)); + COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6); + COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7); + FLOAT4 in = RI_F(input, SAMPLER, (int2)(k, pos.y)); +#if QUANT_BIT == 8 + FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset; +#else + uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); + char16 charWeights = 0; + charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8; + charWeights.s1 = (charWeightsInt4.s0 & 15) - 8; + charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8; + charWeights.s3 = (charWeightsInt4.s1 & 15) - 8; + charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8; + charWeights.s5 = (charWeightsInt4.s2 & 15) - 8; + charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8; + charWeights.s7 = (charWeightsInt4.s3 & 15) - 8; + charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8; + charWeights.s9 = (charWeightsInt4.s4 & 15) - 8; + charWeights.sa = (charWeightsInt4.s5 >> 4) - 8; + charWeights.sb = (charWeightsInt4.s5 & 15) - 8; + charWeights.sc = (charWeightsInt4.s6 >> 4) - 8; + charWeights.sd = (charWeightsInt4.s6 & 15) - 8; + charWeights.se = (charWeightsInt4.s7 >> 4) - 8; + charWeights.sf = (charWeightsInt4.s7 & 15) - 8; + FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset; +#endif + PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef); + + out = mad((FLOAT4)in.x, (FLOAT4)weights.s0123, out); + out = mad((FLOAT4)in.y, (FLOAT4)weights.s4567, out); + out = mad((FLOAT4)in.z, (FLOAT4)weights.s89ab, out); + out = mad((FLOAT4)in.w, (FLOAT4)weights.scdef, out); + } + +#ifdef RELU + out = fmax(out, (FLOAT4)0); +#endif + +#ifdef RELU6 + out = clamp(out, (FLOAT4)0, (FLOAT4)6); +#endif + + WI_F(output, (int2)(pos.x, pos.y), out); +} + +__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2 + __read_only image2d_t input, +#if QUANT_BIT == 8 + __global const char *weight, + __global const float *dequantScaleOffset, +#else + __global const uchar *weight, + __global const float *dequantScaleOffset, +#endif + __read_only image2d_t bias, + __write_only image2d_t output, + __private const int dstChannelC4, + __private const int srcChannelC4, + __private const int batch + ,__private const int blockDim + ,__private const int srcChannel +) { + int2 pos = (int2)(get_global_id(0), get_global_id(1)); //cout/4, b + UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); + int pos_x = pos.x << 2; + int pos_y = pos.y << 1; + + FLOAT4 bias0 = RI_F(bias, SAMPLER, (int2)(pos.x, 0)); + FLOAT4 out0 = bias0, out1 = bias0; + +#if QUANT_BIT == 8 + int weight_offset = pos.x * 16; + int weight_oc_offset = dstChannelC4 * 16; +#else + int weight_offset = pos.x * 8; + int weight_oc_offset = dstChannelC4 * 8; +#endif + + for (int k = 0; k < srcChannelC4; ++k) { + int kindex = (k * 4) / blockDim * dstChannelC4 * 8; + COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(pos.x, dequantScaleOffset + kindex)); + COMPUTE_FLOAT16 scale = (COMPUTE_FLOAT16)(ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6, + ScaleOffset.s0, ScaleOffset.s2, ScaleOffset.s4, ScaleOffset.s6); + COMPUTE_FLOAT16 offset = (COMPUTE_FLOAT16)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7, + ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7); + FLOAT4 in0 = RI_F(input, SAMPLER, (int2)(k, pos_y)); + FLOAT4 in1 = RI_F(input, SAMPLER, (int2)(k, pos_y + 1)); +#if QUANT_BIT == 8 + FLOAT16 weights = CONVERT_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * scale + offset; +#else + uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); + char16 charWeights = 0; + charWeights.s0 = (charWeightsInt4.s0 >> 4) - 8; + charWeights.s1 = (charWeightsInt4.s0 & 15) - 8; + charWeights.s2 = (charWeightsInt4.s1 >> 4) - 8; + charWeights.s3 = (charWeightsInt4.s1 & 15) - 8; + charWeights.s4 = (charWeightsInt4.s2 >> 4) - 8; + charWeights.s5 = (charWeightsInt4.s2 & 15) - 8; + charWeights.s6 = (charWeightsInt4.s3 >> 4) - 8; + charWeights.s7 = (charWeightsInt4.s3 & 15) - 8; + charWeights.s8 = (charWeightsInt4.s4 >> 4) - 8; + charWeights.s9 = (charWeightsInt4.s4 & 15) - 8; + charWeights.sa = (charWeightsInt4.s5 >> 4) - 8; + charWeights.sb = (charWeightsInt4.s5 & 15) - 8; + charWeights.sc = (charWeightsInt4.s6 >> 4) - 8; + charWeights.sd = (charWeightsInt4.s6 & 15) - 8; + charWeights.se = (charWeightsInt4.s7 >> 4) - 8; + charWeights.sf = (charWeightsInt4.s7 & 15) - 8; + FLOAT16 weights = CONVERT_FLOAT16(charWeights) * scale + offset; +#endif + PADZEROSVEC(k, srcChannel, weights.s0123, weights.s4567, weights.s89ab, weights.scdef); + + out0 = mad((FLOAT4)in0.x, (FLOAT4)weights.s0123, out0); + out0 = mad((FLOAT4)in0.y, (FLOAT4)weights.s4567, out0); + out0 = mad((FLOAT4)in0.z, (FLOAT4)weights.s89ab, out0); + out0 = mad((FLOAT4)in0.w, (FLOAT4)weights.scdef, out0); + + out1 = mad((FLOAT4)in1.x, (FLOAT4)weights.s0123, out1); + out1 = mad((FLOAT4)in1.y, (FLOAT4)weights.s4567, out1); + out1 = mad((FLOAT4)in1.z, (FLOAT4)weights.s89ab, out1); + out1 = mad((FLOAT4)in1.w, (FLOAT4)weights.scdef, out1); + } +#ifdef RELU + out0 = fmax(out0, (FLOAT4)0); + out1 = fmax(out1, (FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); + out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); +#endif + + WI_F(output, (int2)(pos.x, pos_y), out0); + if(pos_y + 1 < batch) + WI_F(output, (int2)(pos.x, pos_y + 1), out1); +} diff --git a/source/backend/opencl/execution/cl/gemm_int_mnn_cl.cpp b/source/backend/opencl/execution/cl/gemm_int_mnn_cl.cpp new file mode 100644 index 00000000..429be9a5 --- /dev/null +++ b/source/backend/opencl/execution/cl/gemm_int_mnn_cl.cpp @@ -0,0 +1,185 @@ +#include "opencl_source_map.hpp" +namespace MNN { +const char* gemm_int = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n" +"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" +"#ifdef INPUT_CHANNEL_LEAVE\n" +" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3> 4)-8;\n" +" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n" +" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n" +" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n" +" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n" +" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n" +" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n" +" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n" +" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n" +" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n" +" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n" +" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n" +" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n" +" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n" +" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n" +" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n" +" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n" +"#endif\n" +" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n" +" \n" +" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n" +" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n" +" out=mad((FLOAT4)in.z,(FLOAT4)weights.s89ab,out);\n" +" out=mad((FLOAT4)in.w,(FLOAT4)weights.scdef,out);\n" +" }\n" +" \n" +"#ifdef RELU\n" +" out=fmax(out,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out=clamp(out,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" WI_F(output,(int2)(pos.x,pos.y),out);\n" +"}\n" +"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n" +" __read_only image2d_t input,\n" +"#if QUANT_BIT == 8\n" +" __global const char *weight,\n" +" __global const float *dequantScaleOffset,\n" +"#else\n" +" __global const uchar *weight,\n" +" __global const float *dequantScaleOffset,\n" +"#endif\n" +" __read_only image2d_t bias,\n" +" __write_only image2d_t output,\n" +" __private const int dstChannelC4,\n" +" __private const int srcChannelC4,\n" +" __private const int batch\n" +" ,__private const int blockDim\n" +" ,__private const int srcChannel\n" +") {\n" +" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n" +" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" +" int pos_x=pos.x << 2;\n" +" int pos_y=pos.y << 1;\n" +" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n" +" FLOAT4 out0=bias0,out1=bias0;\n" +" \n" +"#if QUANT_BIT == 8\n" +" int weight_offset=pos.x*16;\n" +" int weight_oc_offset=dstChannelC4*16;\n" +"#else\n" +" int weight_offset=pos.x*8;\n" +" int weight_oc_offset=dstChannelC4*8;\n" +"#endif\n" +" for (int k=0; k> 4)-8;\n" +" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n" +" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n" +" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n" +" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n" +" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n" +" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n" +" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n" +" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n" +" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n" +" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n" +" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n" +" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n" +" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n" +" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n" +" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n" +" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n" +"#endif\n" +" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n" +" \n" +" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n" +" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n" +" out0=mad((FLOAT4)in0.z,(FLOAT4)weights.s89ab,out0);\n" +" out0=mad((FLOAT4)in0.w,(FLOAT4)weights.scdef,out0);\n" +" \n" +" out1=mad((FLOAT4)in1.x,(FLOAT4)weights.s0123,out1);\n" +" out1=mad((FLOAT4)in1.y,(FLOAT4)weights.s4567,out1);\n" +" out1=mad((FLOAT4)in1.z,(FLOAT4)weights.s89ab,out1);\n" +" out1=mad((FLOAT4)in1.w,(FLOAT4)weights.scdef,out1);\n" +" }\n" +"#ifdef RELU\n" +" out0=fmax(out0,(FLOAT4)0);\n" +" out1=fmax(out1,(FLOAT4)0);\n" +"#endif\n" +"#ifdef RELU6\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" +" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" +"#endif\n" +" WI_F(output,(int2)(pos.x,pos_y),out0);\n" +" if(pos_y+1> 4)-8;\n" -" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n" -" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n" -" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n" -" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n" -" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n" -" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n" -" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n" -" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n" -" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n" -" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n" -" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n" -" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n" -" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n" -" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n" -" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n" -" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n" -" \n" -"#else\n" " FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n" -"#endif\n" -" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n" " \n" " out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n" " out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n" @@ -338,24 +280,12 @@ const char* gemm = "}\n" "__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n" " __read_only image2d_t input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *weight,\n" -" __global const float *dequantScaleOffset,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *weight,\n" -" __global const float *dequantScaleOffset,\n" -"#else\n" " __global const FLOAT *weight,\n" -"#endif\n" " __read_only image2d_t bias,\n" " __write_only image2d_t output,\n" " __private const int dstChannelC4,\n" " __private const int srcChannelC4,\n" " __private const int batch\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" ,__private const int blockDim\n" -" ,__private const int srcChannel\n" -"#endif\n" ") {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -364,57 +294,12 @@ const char* gemm = " FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n" " FLOAT4 out0=bias0,out1=bias0;\n" " \n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " int weight_offset=pos.x*16;\n" " int weight_oc_offset=dstChannelC4*16;\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" int weight_offset=pos.x*8;\n" -" int weight_oc_offset=dstChannelC4*8;\n" -"#else\n" -" int weight_offset=pos.x*16;\n" -" int weight_oc_offset=dstChannelC4*16;\n" -"#endif\n" " for (int k=0; k> 4)-8;\n" -" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n" -" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n" -" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n" -" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n" -" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n" -" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n" -" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n" -" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n" -" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n" -" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n" -" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n" -" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n" -" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n" -" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n" -" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n" -" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n" -"#else\n" " FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n" -"#endif\n" -" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n" " \n" " out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n" " out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n" diff --git a/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl b/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl index 1e238d60..2b60dc5f 100644 --- a/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl +++ b/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl @@ -6,6 +6,9 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | #define GLOBAL_SIZE_DIM_2 \ __private int global_size_dim0, __private int global_size_dim1, +#define GLOBAL_SIZE_DIM_3 \ + __private int global_size_dim0, __private int global_size_dim1, __private int global_size_dim2, + #define UNIFORM_BOUNDRY_CHECK_2(index0, index1) \ if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \ return; \ @@ -24,16 +27,22 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | #if WGS >= 8 -__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 +__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3 __global const FLOAT* input, #ifdef USE_IMAGE __read_only image2d_t weight, #else + #if QUANT_BIT == 8 + __global const char *weight, + #else __global const uchar *weight, + #endif #endif __global const FLOAT *dequantScaleOffset, __global const FLOAT *bias, __global FLOAT* output, + __private const int dstChannelAlign, + __private const int srcChannelAlign, __private const int dstChannelC4, __private const int srcChannelC4, __private const int srcChannel, @@ -43,32 +52,101 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 const int lid = get_local_id(0); const int oc = get_global_id(1); //oc/8 const int oc8 = oc << 3; -#if INPUT_CHANNEL_LEAVES_NUM != 0 - const int loop = max((srcChannel + 4 - 1) / 4 - 1, 0); -#else - const int loop = (srcChannel + 4 - 1) / 4; -#endif - __local COMPUTE_FLOAT8 sum[WGS]; - COMPUTE_FLOAT8 out0 = 0; -#ifndef USE_IMAGE - const int weight_offset = oc * srcChannelC4 * 16; -#endif +#if QUANT_BIT == 8 + #if INPUT_CHANNEL_LEAVES_NUM != 0 + const int loop = max((srcChannel + 2 - 1) / 2 - 1, 0); + #else + const int loop = (srcChannel + 2 - 1) / 2; + #endif + #ifndef USE_IMAGE + const int weight_offset = oc * srcChannelC4 * 32; + #endif +#else + #if INPUT_CHANNEL_LEAVES_NUM != 0 + const int loop = max((srcChannel + 4 - 1) / 4 - 1, 0); + #else + const int loop = (srcChannel + 4 - 1) / 4; + #endif + #ifndef USE_IMAGE + const int weight_offset = oc * srcChannelC4 * 16; + #endif +#endif + + COMPUTE_FLOAT8 out0 = 0; + int input_offset = 0, output_offset = oc8; + __local COMPUTE_FLOAT8 sum0[WGS]; +#ifdef COMPUTE_BATCH + const int out_b_idx = get_global_id(2) << 2; //b/4 + __local COMPUTE_FLOAT8 sum1[WGS]; + __local COMPUTE_FLOAT8 sum2[WGS]; + __local COMPUTE_FLOAT8 sum3[WGS]; + COMPUTE_FLOAT8 out1 = 0, out2 = 0, out3 = 0; + input_offset = out_b_idx * srcChannelAlign; + output_offset = oc8 + out_b_idx * dstChannelAlign; +#endif for(int j = lid; j < loop; j+=WGS){ + #if QUANT_BIT == 8 + int k2 = j << 1; + COMPUTE_FLOAT16 scale, offset; + { + #ifdef ASYMMETRIC + COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); + offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); + #else + COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); + offset = 0; + #endif + } + COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + k2)); + #ifdef COMPUTE_BATCH + COMPUTE_FLOAT2 in1 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign + k2)); + COMPUTE_FLOAT2 in2 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign * 2 + k2)); + COMPUTE_FLOAT2 in3 = CONVERT_COMPUTE_FLOAT2(vload2(0, input + input_offset + srcChannelAlign * 3 + k2)); + #endif + #ifdef USE_IMAGE + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(j, oc)))) * scale + offset; + #else + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset; + #endif + { + out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s0, wei.s01234567, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s0, wei.s01234567, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s0, wei.s01234567, out3); + #endif + } + { + out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s1, wei.s89abcdef, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s1, wei.s89abcdef, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s1, wei.s89abcdef, out3); + #endif + } + #else int k4 = j << 2; -#ifdef ASYMMETRIC + #ifdef ASYMMETRIC COMPUTE_FLOAT8 scale, offset; { COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k4 / blockDim) * dstChannelC4 * 8)) / coef); scale = scaleOffset.s02468ace; offset = scaleOffset.s13579bdf; } -#else + #else COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k4 / blockDim) * dstChannelC4 * 4)) / coef); COMPUTE_FLOAT8 offset = 0; -#endif + #endif COMPUTE_FLOAT8 wei; - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4)); + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4 + input_offset)); + #ifdef COMPUTE_BATCH + COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign + k4)); + COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign * 2 + k4)); + COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + srcChannelAlign * 3 + k4)); + #endif #ifdef USE_IMAGE uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(j, oc))); #else @@ -77,39 +155,83 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 { UCHAR4_TO_CHAR8(charWeightsInt40.s0123, scale, offset); out0 = mad((COMPUTE_FLOAT8)in.s0, wei, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s0, wei, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s0, wei, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s0, wei, out3); + #endif } { UCHAR4_TO_CHAR8(charWeightsInt40.s4567, scale, offset); out0 = mad((COMPUTE_FLOAT8)in.s1, wei, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s1, wei, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s1, wei, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s1, wei, out3); + #endif } { UCHAR4_TO_CHAR8(charWeightsInt40.s89ab, scale, offset); out0 = mad((COMPUTE_FLOAT8)in.s2, wei, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s2, wei, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s2, wei, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s2, wei, out3); + #endif } { UCHAR4_TO_CHAR8(charWeightsInt40.scdef, scale, offset); out0 = mad((COMPUTE_FLOAT8)in.s3, wei, out0); + #ifdef COMPUTE_BATCH + out1 = mad((COMPUTE_FLOAT8)in1.s3, wei, out1); + out2 = mad((COMPUTE_FLOAT8)in2.s3, wei, out2); + out3 = mad((COMPUTE_FLOAT8)in3.s3, wei, out3); + #endif } + #endif } #if INPUT_CHANNEL_LEAVES_NUM != 0 { + #if QUANT_BIT == 8 + int k2 = loop << 1; + COMPUTE_FLOAT16 scale, offset; + { + #ifdef ASYMMETRIC + COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); + offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); + #else + COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); + offset = 0; + #endif + } + #ifdef USE_IMAGE + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(loop, oc)))) * scale + offset; + #else + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(loop, weight + weight_offset)) * scale + offset; + #endif + { + out0 = mad((COMPUTE_FLOAT8)input[k2], wei.s01234567, out0); + } + #else int k4 = loop << 2; -#ifdef ASYMMETRIC + #ifdef ASYMMETRIC COMPUTE_FLOAT8 scale, offset; { COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k4 / blockDim) * dstChannelC4 * 8)) / coef); scale = scaleOffset.s02468ace; offset = scaleOffset.s13579bdf; } -#else + #else COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k4 / blockDim) * dstChannelC4 * 4)) / coef); COMPUTE_FLOAT8 offset = 0; -#endif + #endif COMPUTE_FLOAT8 wei; #ifdef USE_IMAGE uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(loop, oc))); #else - uchar16 charWeightsInt40 = vload16(j, weight + weight_offset); + uchar16 charWeightsInt40 = vload16(loop, weight + weight_offset); #endif { UCHAR4_TO_CHAR8(charWeightsInt40.s0123, scale, offset); @@ -127,17 +249,28 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 out0 = mad((COMPUTE_FLOAT8)input[k4 + 2], wei, out0); } #endif + #endif } #endif - sum[lid] = out0; + sum0[lid] = out0; + #ifdef COMPUTE_BATCH + sum1[lid] = out1; sum2[lid] = out2; sum3[lid] = out3; + #endif barrier(CLK_LOCAL_MEM_FENCE); for(int i = WGS/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; + if (lid < i){ + sum0[lid] = sum0[lid] + sum0[lid + i]; + #ifdef COMPUTE_BATCH + sum1[lid] = sum1[lid] + sum1[lid + i]; + sum2[lid] = sum2[lid] + sum2[lid + i]; + sum3[lid] = sum3[lid] + sum3[lid + i]; + #endif + } barrier(CLK_LOCAL_MEM_FENCE); } if(lid == 0){ - out0 = sum[0] + CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8)); + COMPUTE_FLOAT8 vBias = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8)); + out0 = sum0[0] + vBias; #ifdef RELU out0 = fmax(out0, (COMPUTE_FLOAT8)0); #endif @@ -146,132 +279,43 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6); #endif #ifdef OUTPUT_CHANNEL_LEAVES - vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8); + vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + output_offset); if(oc8 + 4 < dstChannelC4 * 4) - vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4); + vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + 4 + output_offset); #else - vstore8(CONVERT_FLOAT8(out0), 0, output + oc8); + vstore8(CONVERT_FLOAT8(out0), 0, output + output_offset); + #endif + #ifdef COMPUTE_BATCH + out1 = sum1[0] + vBias; out2 = sum2[0] + vBias; out3 = sum3[0] + vBias; + #ifdef RELU + out1 = fmax(out1, (COMPUTE_FLOAT8)0);out2 = fmax(out2, (COMPUTE_FLOAT8)0);out3 = fmax(out3, (COMPUTE_FLOAT8)0); + #endif + #ifdef RELU6 + out1 = clamp(out1, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);out2 = clamp(out2, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6);out3 = clamp(out3, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6); + #endif + vstore8(CONVERT_FLOAT8(out1), 0, output + output_offset + dstChannelAlign); + vstore8(CONVERT_FLOAT8(out2), 0, output + output_offset + dstChannelAlign + dstChannelAlign); + vstore8(CONVERT_FLOAT8(out3), 0, output + output_offset + dstChannelAlign + dstChannelAlign + dstChannelAlign); #endif } } - -__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2 +#else +__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3 __global const FLOAT* input, #ifdef USE_IMAGE __read_only image2d_t weight, #else + #if QUANT_BIT == 8 __global const char *weight, -#endif - __global const FLOAT *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int srcChannel, - __private const int blockNum, - __private const int blockDim, - __private const float coef) { - const int lid = get_local_id(0); - const int oc = get_global_id(1); //oc/8 - const int oc8 = oc << 3; -#if INPUT_CHANNEL_LEAVES_NUM != 0 - const int loop = max((srcChannel + 2 - 1) / 2 - 1, 0); -#else - const int loop = (srcChannel + 2 - 1) / 2; -#endif - __local COMPUTE_FLOAT8 sum[WGS]; -#ifndef USE_IMAGE - const int weight_offset = oc * srcChannelC4 * 32; -#endif - COMPUTE_FLOAT8 out0 = 0; - for(int j = lid; j < loop; j+=WGS){ - int k2 = j << 1; - COMPUTE_FLOAT16 scale, offset; - { - #ifdef ASYMMETRIC - COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); - offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); - #else - COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); - offset = 0; - #endif - } - COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + k2)); - #ifdef USE_IMAGE - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(j, oc)))) * scale + offset; - #else - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset; - #endif - { - out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0); - } - { - out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0); - } - } -#if INPUT_CHANNEL_LEAVES_NUM != 0 - { - int k2 = loop << 1; - COMPUTE_FLOAT16 scale, offset; - { - #ifdef ASYMMETRIC - COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + (k2 / blockDim) * dstChannelC4 * 8)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); - offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); - #else - COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + (k2 / blockDim) * dstChannelC4 * 4)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); - offset = 0; - #endif - } - #ifdef USE_IMAGE - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(loop, oc)))) * scale + offset; - #else - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(j, weight + weight_offset)) * scale + offset; - #endif - { - out0 = mad((COMPUTE_FLOAT8)input[k2], wei.s01234567, out0); - } - } -#endif - sum[lid] = out0; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = WGS/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - if(lid == 0){ - out0 = sum[0] + CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8)); - #ifdef RELU - out0 = fmax(out0, (COMPUTE_FLOAT8)0); - #endif - - #ifdef RELU6 - out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6); - #endif - #ifdef OUTPUT_CHANNEL_LEAVES - vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8); - if(oc8 + 4 < dstChannelC4 * 4) - vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4); - #else - vstore8(CONVERT_FLOAT8(out0), 0, output + oc8); - #endif - } -} -#else -__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 - __global const FLOAT* input, -#ifdef USE_IMAGE - __read_only image2d_t weight, -#else + #else __global const uchar *weight, + #endif #endif __global const FLOAT *dequantScaleOffset, __global const FLOAT *bias, __global FLOAT* output, + __private const int dstChannelAlign, + __private const int srcChannelAlign, __private const int dstChannelC4, __private const int srcChannelC4, __private const int srcChannel, @@ -283,29 +327,83 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 UNIFORM_BOUNDRY_CHECK_2(ic, oc); const int oc8 = oc << 3; - - const int loop = (blockDim + 4 - 1) / 4; -#if INPUT_CHANNEL_LEAVES_NUM != 0 + +#if QUANT_BIT == 8 + const int loop = (blockDim + 2 - 1) / 2; + #if INPUT_CHANNEL_LEAVES_NUM != 0 const int loop_end = max(loop - 1, 0); -#else + #else const int loop_end = loop; + #endif + #ifndef USE_IMAGE + const int weight_offset = oc * srcChannelC4 * 32; + #endif +#else + const int loop = (blockDim + 4 - 1) / 4; + #if INPUT_CHANNEL_LEAVES_NUM != 0 + const int loop_end = max(loop - 1, 0); + #else + const int loop_end = loop; + #endif + #ifndef USE_IMAGE + const int weight_offset = oc * srcChannelC4 * 16; + #endif #endif COMPUTE_FLOAT8 out0 = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8)); -#ifndef USE_IMAGE - const int weight_offset = oc * srcChannelC4 * 16; -#endif for (int i = 0; i < blockNum; i++){ -#ifdef ASYMMETRIC + #if QUANT_BIT == 8 + COMPUTE_FLOAT16 scale, offset; + { + #ifdef ASYMMETRIC + COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); + offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); + #else + COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef); + scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); + offset = 0; + #endif + } + for (int j = 0; j < loop_end; j++) { + int k = i * loop + j; + COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + (k << 1))); + #ifdef USE_IMAGE + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset; + #else + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset; + #endif + { + out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0); + } + { + out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0); + } + } + #if INPUT_CHANNEL_LEAVES_NUM != 0 + { + int k = i * loop + loop_end; + #ifdef USE_IMAGE + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset; + #else + COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset; + #endif + { + out0 = mad((COMPUTE_FLOAT8)input[k << 1], wei.s01234567, out0); + } + } + #endif + #else + #ifdef ASYMMETRIC COMPUTE_FLOAT8 scale, offset; { COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef); scale = scaleOffset.s02468ace; offset = scaleOffset.s13579bdf; } -#else + #else COMPUTE_FLOAT8 scale = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef); COMPUTE_FLOAT8 offset = 0; -#endif + #endif for (int j = 0; j < loop_end; j++) { int k = i * loop + j; COMPUTE_FLOAT8 wei; @@ -360,94 +458,7 @@ __kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2 #endif } #endif -} -#ifdef RELU - out0 = fmax(out0, (COMPUTE_FLOAT8)0); -#endif - -#ifdef RELU6 - out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6); -#endif - #ifdef OUTPUT_CHANNEL_LEAVES - vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8); - if(oc8 + 4 < dstChannelC4 * 4) - vstore4(CONVERT_FLOAT4(out0.s4567), 0, output + oc8 + 4); - #else - vstore8(CONVERT_FLOAT8(out0), 0, output + oc8); #endif -} - -__kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2 - __global const FLOAT* input, -#ifdef USE_IMAGE - __read_only image2d_t weight, -#else - __global const char *weight, -#endif - __global const FLOAT *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int srcChannel, - __private const int blockNum, - __private const int blockDim, - __private const float coef) { - const int ic = get_global_id(0); - const int oc = get_global_id(1); //oc/8 - UNIFORM_BOUNDRY_CHECK_2(ic, oc); - const int oc8 = oc << 3; - const int loop = (blockDim + 2 - 1) / 2; -#if INPUT_CHANNEL_LEAVES_NUM != 0 - const int loop_end = max(loop - 1, 0); -#else - const int loop_end = loop; -#endif -#ifndef USE_IMAGE - const int weight_offset = oc * srcChannelC4 * 32; -#endif - COMPUTE_FLOAT8 out0 = CONVERT_COMPUTE_FLOAT8(vload8(0, bias + oc8)); - for (int i = 0; i < blockNum; i++){ - COMPUTE_FLOAT16 scale, offset; - { - #ifdef ASYMMETRIC - COMPUTE_FLOAT16 scaleOffset = CONVERT_COMPUTE_FLOAT16(convert_float16(vload16(0, dequantScaleOffset + oc8 * 2 + i * dstChannelC4 * 8)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset.s02468ace, scaleOffset.s02468ace); - offset = (COMPUTE_FLOAT16)(scaleOffset.s13579bdf, scaleOffset.s13579bdf); - #else - COMPUTE_FLOAT8 scaleOffset = CONVERT_COMPUTE_FLOAT8(convert_float8(vload8(0, dequantScaleOffset + oc8 + i * dstChannelC4 * 4)) / coef); - scale = (COMPUTE_FLOAT16)(scaleOffset, scaleOffset); - offset = 0; - #endif - } - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - COMPUTE_FLOAT2 in = CONVERT_COMPUTE_FLOAT2(vload2(0, input + (k << 1))); - #ifdef USE_IMAGE - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset; - #else - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset; - #endif - { - out0 = mad((COMPUTE_FLOAT8)in.s0, wei.s01234567, out0); - } - { - out0 = mad((COMPUTE_FLOAT8)in.s1, wei.s89abcdef, out0); - } - } - #if INPUT_CHANNEL_LEAVES_NUM != 0 - { - int k = i * loop + loop_end; - #ifdef USE_IMAGE - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(k, oc)))) * scale + offset; - #else - COMPUTE_FLOAT16 wei = CONVERT_COMPUTE_FLOAT16(vload16(k, weight + weight_offset)) * scale + offset; - #endif - { - out0 = mad((COMPUTE_FLOAT8)input[k << 1], wei.s01234567, out0); - } - } - #endif } #ifdef RELU out0 = fmax(out0, (COMPUTE_FLOAT8)0); @@ -456,7 +467,6 @@ __kernel void gemv_conv_c8_int8_buf(GLOBAL_SIZE_DIM_2 #ifdef RELU6 out0 = clamp(out0, (COMPUTE_FLOAT8)0, (COMPUTE_FLOAT8)6); #endif - #ifdef OUTPUT_CHANNEL_LEAVES vstore4(CONVERT_FLOAT4(out0.s0123), 0, output + oc8); if(oc8 + 4 < dstChannelC4 * 4) diff --git a/source/backend/opencl/execution/cl/gemv_conv1x1_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/gemv_conv1x1_buf_mnn_cl.cpp index c824d4d9..0b65eecc 100644 --- a/source/backend/opencl/execution/cl/gemv_conv1x1_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/gemv_conv1x1_buf_mnn_cl.cpp @@ -7,19 +7,26 @@ const char* gemv_conv1x1_buf = "#endif\n" "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" "#define GLOBAL_SIZE_DIM_2 "" __private int global_size_dim0,__private int global_size_dim1,\n" +"#define GLOBAL_SIZE_DIM_3 "" __private int global_size_dim0,__private int global_size_dim1,__private int global_size_dim2,\n" "#define UNIFORM_BOUNDRY_CHECK_2(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n" "#define UCHAR4_TO_CHAR8(b, scale, offset) "" wei.s0 = (COMPUTE_FLOAT)((b.s0 >> 4) - 8); "" wei.s1 = (COMPUTE_FLOAT)((b.s0 & 15) - 8); "" wei.s2 = (COMPUTE_FLOAT)((b.s1 >> 4) - 8); "" wei.s3 = (COMPUTE_FLOAT)((b.s1 & 15) - 8); "" wei.s4 = (COMPUTE_FLOAT)((b.s2 >> 4) - 8); "" wei.s5 = (COMPUTE_FLOAT)((b.s2 & 15) - 8); "" wei.s6 = (COMPUTE_FLOAT)((b.s3 >> 4) - 8); "" wei.s7 = (COMPUTE_FLOAT)((b.s3 & 15) - 8); "" wei=wei*scale+offset;\n" "#if WGS >= 8\n" -"__kernel void gemv_conv_c8_int4_buf(GLOBAL_SIZE_DIM_2\n" +"__kernel void gemv_conv_c8_buf(GLOBAL_SIZE_DIM_3\n" " __global const FLOAT* input,\n" "#ifdef USE_IMAGE\n" " __read_only image2d_t weight,\n" "#else\n" +" #if QUANT_BIT == 8\n" +" __global const char *weight,\n" +" #else\n" " __global const uchar *weight,\n" +" #endif\n" "#endif\n" " __global const FLOAT *dequantScaleOffset,\n" " __global const FLOAT *bias,\n" " __global FLOAT* output,\n" +" __private const int dstChannelAlign,\n" +" __private const int srcChannelAlign,\n" " __private const int dstChannelC4,\n" " __private const int srcChannelC4,\n" " __private const int srcChannel,\n" @@ -29,32 +36,100 @@ const char* gemv_conv1x1_buf = " const int lid=get_local_id(0);\n" " const int oc=get_global_id(1); //oc/8\n" " const int oc8=oc << 3;\n" -"#if INPUT_CHANNEL_LEAVES_NUM != 0\n" -" const int loop=max((srcChannel+4-1)/4-1,0);\n" -"#else\n" -" const int loop=(srcChannel+4-1)/4;\n" -"#endif\n" -" __local COMPUTE_FLOAT8 sum[WGS];\n" -" COMPUTE_FLOAT8 out0=0;\n" -"#ifndef USE_IMAGE\n" -" const int weight_offset=oc*srcChannelC4*16;\n" -"#endif\n" " \n" +"#if QUANT_BIT == 8\n" +" #if INPUT_CHANNEL_LEAVES_NUM != 0\n" +" const int loop=max((srcChannel+2-1)/2-1,0);\n" +" #else\n" +" const int loop=(srcChannel+2-1)/2;\n" +" #endif\n" +" #ifndef USE_IMAGE\n" +" const int weight_offset=oc*srcChannelC4*32;\n" +" #endif\n" +"#else\n" +" #if INPUT_CHANNEL_LEAVES_NUM != 0\n" +" const int loop=max((srcChannel+4-1)/4-1,0);\n" +" #else\n" +" const int loop=(srcChannel+4-1)/4;\n" +" #endif\n" +" #ifndef USE_IMAGE\n" +" const int weight_offset=oc*srcChannelC4*16;\n" +" #endif\n" +"#endif\n" +" COMPUTE_FLOAT8 out0=0;\n" +" int input_offset=0,output_offset=oc8;\n" +" __local COMPUTE_FLOAT8 sum0[WGS];\n" +"#ifdef COMPUTE_BATCH\n" +" const int out_b_idx=get_global_id(2) << 2; //b/4\n" +" __local COMPUTE_FLOAT8 sum1[WGS];\n" +" __local COMPUTE_FLOAT8 sum2[WGS];\n" +" __local COMPUTE_FLOAT8 sum3[WGS];\n" +" COMPUTE_FLOAT8 out1=0,out2=0,out3=0;\n" +" input_offset=out_b_idx*srcChannelAlign;\n" +" output_offset=oc8+out_b_idx*dstChannelAlign;\n" +"#endif\n" " for(int j=lid; j0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid= shape.w) return; vstore4(in3, 0, output_ptr + output_offset + 12); + #else + //not support #endif #endif } @@ -167,11 +169,12 @@ __kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS int idx = c * shape.w + w; // c/4*w int idy = nh; // n*h + INPUT_TYPE4 in0, in1, in2, in3; #ifdef USE_IMAGE - INPUT_TYPE4 in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy)); - INPUT_TYPE4 in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy)); - INPUT_TYPE4 in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy)); - INPUT_TYPE4 in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy)); + in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy)); + in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy)); + in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy)); + in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy)); #else #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w; @@ -181,22 +184,24 @@ __kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS tmp1 = vload4(0, input_ptr + input_offset + stride); tmp2 = vload4(0, input_ptr + input_offset + stride + stride); tmp3 = vload4(0, input_ptr + input_offset + stride + stride + stride); - INPUT_TYPE4 in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x); - INPUT_TYPE4 in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y); - INPUT_TYPE4 in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z); - INPUT_TYPE4 in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w); + in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x); + in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y); + in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z); + in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w); #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c; - INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset); - INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + shape.y); - INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y); - INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y); + in0 = vload4(0, input_ptr + input_offset); + in1 = vload4(0, input_ptr + input_offset + shape.y); + in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y); + in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y); #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 int input_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4; - INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset); - INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + 4); - INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + 8); - INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + 12); + in0 = vload4(0, input_ptr + input_offset); + in1 = vload4(0, input_ptr + input_offset + 4); + in2 = vload4(0, input_ptr + input_offset + 8); + in3 = vload4(0, input_ptr + input_offset + 12); + #else + // not support #endif #endif const int offset = idy * shape.w * 4; diff --git a/source/backend/opencl/execution/cl/glmem_convert_mnn_cl.cpp b/source/backend/opencl/execution/cl/glmem_convert_mnn_cl.cpp index c6379279..c3679cb1 100644 --- a/source/backend/opencl/execution/cl/glmem_convert_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/glmem_convert_mnn_cl.cpp @@ -132,6 +132,8 @@ const char* glmem_convert = " vstore4(in2,0,output_ptr+output_offset+8);\n" " if(w+3 >= shape.w) return;\n" " vstore4(in3,0,output_ptr+output_offset+12);\n" +" #else\n" +" //not support\n" " #endif\n" "#endif\n" "}\n" @@ -157,11 +159,12 @@ const char* glmem_convert = " \n" " int idx=c*shape.w+w; // c/4*w\n" " int idy=nh; // n*h\n" +" INPUT_TYPE4 in0,in1,in2,in3;\n" "#ifdef USE_IMAGE\n" -" INPUT_TYPE4 in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n" -" INPUT_TYPE4 in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n" -" INPUT_TYPE4 in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n" -" INPUT_TYPE4 in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n" +" in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n" +" in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n" +" in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n" +" in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n" "#else\n" " #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" " int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n" @@ -171,22 +174,24 @@ const char* glmem_convert = " tmp1=vload4(0,input_ptr+input_offset+stride);\n" " tmp2=vload4(0,input_ptr+input_offset+stride+stride);\n" " tmp3=vload4(0,input_ptr+input_offset+stride+stride+stride);\n" -" INPUT_TYPE4 in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n" -" INPUT_TYPE4 in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n" -" INPUT_TYPE4 in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n" -" INPUT_TYPE4 in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n" +" in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n" +" in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n" +" in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n" +" in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n" " #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" " int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n" -" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n" -" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+shape.y);\n" -" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n" -" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n" +" in0=vload4(0,input_ptr+input_offset);\n" +" in1=vload4(0,input_ptr+input_offset+shape.y);\n" +" in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n" +" in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n" " #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" " int input_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n" -" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n" -" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+4);\n" -" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+8);\n" -" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+12);\n" +" in0=vload4(0,input_ptr+input_offset);\n" +" in1=vload4(0,input_ptr+input_offset+4);\n" +" in2=vload4(0,input_ptr+input_offset+8);\n" +" in3=vload4(0,input_ptr+input_offset+12);\n" +" #else\n" +" // not support\n" " #endif\n" "#endif\n" " const int offset=idy*shape.w*4;\n" diff --git a/source/backend/opencl/execution/cl/groupnorm_buf.cl b/source/backend/opencl/execution/cl/groupnorm_buf.cl index 0e2dde0f..e8b0a4fb 100644 --- a/source/backend/opencl/execution/cl/groupnorm_buf.cl +++ b/source/backend/opencl/execution/cl/groupnorm_buf.cl @@ -2,6 +2,7 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif +#if LOCAL_SIZE > 1 __kernel void groupnorm_plain_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, #ifdef DOUBLE_INPUTS __global const FLOAT * input0, @@ -242,3 +243,4 @@ __kernel void groupnorm_plain_buf(__private int global_dim0, __private int globa #endif } } +#endif diff --git a/source/backend/opencl/execution/cl/groupnorm_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/groupnorm_buf_mnn_cl.cpp index f4eb684d..dbba3336 100644 --- a/source/backend/opencl/execution/cl/groupnorm_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/groupnorm_buf_mnn_cl.cpp @@ -5,6 +5,7 @@ const char* groupnorm_buf = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" +"#if LOCAL_SIZE>1\n" "__kernel void groupnorm_plain_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" "#ifdef DOUBLE_INPUTS\n" " __global const FLOAT*input0,\n" @@ -226,6 +227,7 @@ const char* groupnorm_buf = "#endif\n" " }\n" "}\n" +"#endif\n" ; #endif } diff --git a/source/backend/opencl/execution/cl/input_transe_buf.cl b/source/backend/opencl/execution/cl/input_transe_buf.cl index 1b86b6e7..091c3e41 100644 --- a/source/backend/opencl/execution/cl/input_transe_buf.cl +++ b/source/backend/opencl/execution/cl/input_transe_buf.cl @@ -49,8 +49,9 @@ __kernel void conv_transe_c4_c1( w * output_x_pitch; FLOAT4 value = vload4(0, input + input_offset); + FLOAT *value_ptr = (FLOAT*)&value; for(int i = 0; i < 4 && cout + i < input_channel; ++i){ - output[output_offset + i * output_f_pitch] = value[i]; + output[output_offset + i * output_f_pitch] = value_ptr[i]; } } diff --git a/source/backend/opencl/execution/cl/input_transe_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/input_transe_buf_mnn_cl.cpp index 13a83c6d..292a7ed6 100644 --- a/source/backend/opencl/execution/cl/input_transe_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/input_transe_buf_mnn_cl.cpp @@ -48,8 +48,9 @@ const char* input_transe_buf = " w*output_x_pitch;\n" " \n" " FLOAT4 value=vload4(0,input+input_offset);\n" +" FLOAT *value_ptr=(FLOAT*)&value;\n" " for(int i=0; i<4 && cout+i (int4)0) || (out > (int4)0 && in1 < (int4)0)) ? out + in1 : out; #else - float4 out = LOOP_BINARY_OPERATOR; + float4 out = OPERATOR; #endif WI_DATA(output, (int2)(co * dst_width + wo, no * dst_height + ho), CONVERT_OUTPUT_I4(out)); } } -#ifdef COMPUTE_CUMSUM + __kernel void loop_cumsum(__private int global_dim0, __private int global_dim1, __private int global_dim2, __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, __private const int input0Stride0, @@ -455,21 +457,20 @@ __kernel void loop_cumsum(__private int global_dim0, __private int global_dim1, int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2; int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2; - float in0 = 0; + float4 in0 = 0; if(offsets.z != offsets.y){ - in0 = (float)input0[inputIndex0]; + in0.x = (float)input0[inputIndex0]; } for(int i = 0; i < loopNumber; ++i){ int4 offset = (int4)i * steps + offsets; - float in1 = (float)input1[inputIndex1 + offset.z]; - float out = LOOP_BINARY_OPERATOR; + float4 in1; + in1.x = (float)input1[inputIndex1 + offset.z]; + float4 out = OPERATOR; - output[outputIndex + offset.x] = (OUTPUT_TYPE)out; - in0 = out; + output[outputIndex + offset.x] = (OUTPUT_TYPE)out.x; + in0.x = out.x; } } } -#endif -#endif diff --git a/source/backend/opencl/execution/cl/loop_buf.cl b/source/backend/opencl/execution/cl/loop_buf.cl index 5bbe5d98..72392813 100644 --- a/source/backend/opencl/execution/cl/loop_buf.cl +++ b/source/backend/opencl/execution/cl/loop_buf.cl @@ -387,7 +387,9 @@ __kernel void pack_buf(__private int global_dim0, __private int global_dim1, __p } } -#ifdef LOOP_BINARY_OPERATOR +#ifndef OPERATOR + #define OPERATOR in0 + in1 +#endif __kernel void loop_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, __private const int input0Stride0, @@ -414,11 +416,11 @@ __kernel void loop_binary_buf(__private int global_dim0, __private int global_di int in0 = (int)input0[inputIndex0]; int in1 = (int)input1[inputIndex1]; int out = in0 % in1; - out = ((out < (int4)0 && in1 > (int4)0) || (out > (int4)0 && in1 < (int4)0)) ? out + in1 : out; + out = ((out < 0 && in1 > 0) || (out > 0 && in1 < 0)) ? out + in1 : out; #else float in0 = (float)input0[inputIndex0]; float in1 = (float)input1[inputIndex1]; - float out = LOOP_BINARY_OPERATOR; + float out = OPERATOR; #endif output[outputIndex] = (OUTPUT_TYPE)out; } @@ -458,11 +460,10 @@ __kernel void loop_cumsum_buf(__private int global_dim0, __private int global_di for(int i = 0; i < loopNumber; ++i){ int4 offset = (int4)i * steps + offsets; float in1 = (float)input1[inputIndex1 + offset.z]; - float out = LOOP_BINARY_OPERATOR; + float out = OPERATOR; output[outputIndex + offset.x] = (OUTPUT_TYPE)out; in0 = out; } } } -#endif diff --git a/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp index 4a0400e8..d23338e4 100644 --- a/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp @@ -379,7 +379,9 @@ const char* loop_buf = " vstore4(value,0,output+b*b_dst_pitch+c_4*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch);\n" " }\n" "}\n" -"#ifdef LOOP_BINARY_OPERATOR\n" +"#ifndef OPERATOR\n" +" #define OPERATOR in0+in1\n" +"#endif\n" "__kernel void loop_binary_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" " __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n" " __private const int input0Stride0,\n" @@ -406,11 +408,11 @@ const char* loop_buf = " int in0=(int)input0[inputIndex0];\n" " int in1=(int)input1[inputIndex1];\n" " int out=in0 % in1;\n" -" out=((out<(int4)0 && in1>(int4)0) || (out>(int4)0 && in1<(int4)0)) ? out+in1 : out;\n" +" out=((out<0 && in1>0) || (out>0 && in1<0)) ? out+in1 : out;\n" " #else\n" " float in0=(float)input0[inputIndex0];\n" " float in1=(float)input1[inputIndex1];\n" -" float out=LOOP_BINARY_OPERATOR;\n" +" float out=OPERATOR;\n" " #endif\n" " output[outputIndex]=(OUTPUT_TYPE)out;\n" " }\n" @@ -449,14 +451,13 @@ const char* loop_buf = " for(int i=0; i(int4)0) || (out>(int4)0 && in1<(int4)0)) ? out+in1 : out;\n" " #else\n" -" float4 out=LOOP_BINARY_OPERATOR;\n" +" float4 out=OPERATOR;\n" " #endif\n" " \n" " WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n" " }\n" "}\n" -"#ifdef COMPUTE_CUMSUM\n" "__kernel void loop_cumsum(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" " __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n" " __private const int input0Stride0,\n" @@ -442,22 +443,21 @@ const char* loop = " int inputIndex1=z*input1Stride0+y*input1Stride1+x*input1Stride2;\n" " int outputIndex=z*outputStride0+y*outputStride1+x*outputStride2;\n" " \n" -" float in0=0;\n" +" float4 in0=0;\n" " if(offsets.z != offsets.y){\n" -" in0=(float)input0[inputIndex0];\n" +" in0.x=(float)input0[inputIndex0];\n" " }\n" " \n" " for(int i=0; i OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "self_attention_buf", self_attention_buf }, #endif - { "performance", performance }, { "winogradTransformSource2_3_1", winogradTransformSource2_3_1 }, #ifndef MNN_OPENCL_BUFFER_CLOSED { "gemv_conv1x1_buf", gemv_conv1x1_buf }, @@ -262,6 +262,7 @@ const std::map OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "gemm_buf", gemm_buf }, #endif + { "conv_2d_int", conv_2d_int }, { "copy_buffer_to_image2d", copy_buffer_to_image2d }, { "loop", loop }, #ifndef MNN_OPENCL_BUFFER_CLOSED @@ -296,6 +297,7 @@ const std::map OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "conv_2d_buf", conv_2d_buf }, #endif + { "gemm_int", gemm_int }, { "buffer_to_image", buffer_to_image }, { "winogradTransformDest2_3_1", winogradTransformDest2_3_1 }, #ifndef MNN_OPENCL_BUFFER_CLOSED diff --git a/source/backend/opencl/execution/cl/performance.cl b/source/backend/opencl/execution/cl/performance.cl deleted file mode 100644 index 982dc50a..00000000 --- a/source/backend/opencl/execution/cl/performance.cl +++ /dev/null @@ -1,95 +0,0 @@ - -#define MAD_V4(x, y) \ - x = mad(y, x, y); \ - y = mad(x, y, x); \ - x = mad(y, x, y); \ - y = mad(x, y, x); -#define MAD_V16(x, y) \ - MAD_V4(x, y); \ - MAD_V4(x, y); \ - MAD_V4(x, y); \ - MAD_V4(x, y); -#define MAD_V64(x, y) \ - MAD_V16(x, y); \ - MAD_V16(x, y); \ - MAD_V16(x, y); \ - MAD_V16(x, y); -#define MAD_V128(x, y) \ - MAD_V64(x, y); \ - MAD_V64(x, y); \ - MAD_V64(x, y); \ - MAD_V64(x, y); -#define MAD_V256(x, y) \ - MAD_V128(x, y); \ - MAD_V128(x, y); \ - MAD_V128(x, y); \ - MAD_V128(x, y); - -#ifdef MNN_SUPPORT_FP16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#endif - -__kernel void float_precision(__global float* output_ptr, float mul_value) { - float mul_x = mul_value; - float mul_y = (float)get_local_id(0); - - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - output_ptr[get_global_id(0)] = mul_y; -} - -__kernel void half4_precision(__global half* output_ptr, float mul_value) { - half mul = (half)mul_value; - half4 mul_x = (half4)(mul); - half4 mul_y = (half4)get_local_id(0); - - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - MAD_V256(mul_x, mul_y); - - output_ptr[get_global_id(0)] = (mul_y.S0) + (mul_y.S1) + (mul_y.S2) + (mul_y.S3); -} diff --git a/source/backend/opencl/execution/cl/performance_mnn_cl.cpp b/source/backend/opencl/execution/cl/performance_mnn_cl.cpp deleted file mode 100644 index 04b42282..00000000 --- a/source/backend/opencl/execution/cl/performance_mnn_cl.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include "opencl_source_map.hpp" -namespace MNN { -const char* performance = -"#define MAD_V4(x, y) "" x = mad(y, x, y); "" y = mad(x, y, x); "" x = mad(y, x, y); "" y=mad(x,y,x);\n" -"#define MAD_V16(x, y) "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x, y); "" MAD_V4(x,y);\n" -"#define MAD_V64(x, y) "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x, y); "" MAD_V16(x,y);\n" -"#define MAD_V128(x, y) "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x, y); "" MAD_V64(x,y);\n" -"#define MAD_V256(x, y) "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x, y); "" MAD_V128(x,y);\n" -"#ifdef MNN_SUPPORT_FP16\n" -"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" -"#endif\n" -"__kernel void float_precision(__global float* output_ptr,float mul_value) {\n" -" float mul_x=mul_value;\n" -" float mul_y=(float)get_local_id(0);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" output_ptr[get_global_id(0)]=mul_y;\n" -"}\n" -"__kernel void half4_precision(__global half* output_ptr,float mul_value) {\n" -" half mul=(half)mul_value;\n" -" half4 mul_x=(half4)(mul);\n" -" half4 mul_y=(half4)get_local_id(0);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" MAD_V256(mul_x,mul_y);\n" -" output_ptr[get_global_id(0)]=(mul_y.S0)+(mul_y.S1)+(mul_y.S2)+(mul_y.S3);\n" -"}\n" -; -} diff --git a/source/backend/opencl/execution/cl/pooling.cl b/source/backend/opencl/execution/cl/pooling.cl index fc32874f..b0a45937 100644 --- a/source/backend/opencl/execution/cl/pooling.cl +++ b/source/backend/opencl/execution/cl/pooling.cl @@ -29,6 +29,9 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, const int input_width_start = mad24(output_width_idx, stride_shape.y, -pad_shape.y); const int input_channel_start = mul24(output_channel_idx, input_shape.y); + #ifdef RETURN_REDICE + int4 redice = (int4)0; + #endif #ifdef POOL_AVG FLOAT4 output_result = 0; for (int height = 0; height < kernel_shape.x; height++) { @@ -58,9 +61,6 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, output_result = output_result * block_float_req; #else FLOAT4 output_result = (FLOAT4)(-FLT_MAX); - #if RETURN_REDICE - int4 redice = (int4)0; - #endif for (int height = 0; height < kernel_shape.x; height++) { int input_height_idx = input_height_start + height; input_height_idx = @@ -73,7 +73,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, if (input_width_idx != -1) { FLOAT4 input_data = RI_F(input, SAMPLER, (int2)(input_width_idx, input_height_idx)); - #if RETURN_REDICE + #ifdef RETURN_REDICE redice = input_data > output_result ? (int4)((input_height_start + height) * input_shape.y + input_width_start + width) : redice; #endif output_result = fmax(output_result, input_data); @@ -85,12 +85,12 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, const int output_channel_width_idx = mad24(output_channel_idx, output_width, output_width_idx); WI_F(output, (int2)(output_channel_width_idx, output_batch_height_idx), output_result); - #if RETURN_REDICE + #ifdef RETURN_REDICE WI_F(rediceOutput, (int2)(output_channel_width_idx, output_batch_height_idx), CONVERT_FLOAT4(redice)); #endif } -#ifdef LOCAL_SIZE +#if LOCAL_SIZE > 1 __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __private const int2 input_shape, __private const int output_height, __private const int2 pad_shape, __private const int2 stride_shape, @@ -105,10 +105,10 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, FLOAT4 output_result = 0; #else FLOAT4 output_result = (FLOAT4)(-FLT_MAX); -#if RETURN_REDICE +#endif +#ifdef RETURN_REDICE int4 redice = (int4)0; int4 local rediceId[LOCAL_SIZE]; -#endif #endif FLOAT4 local sum_mnn[LOCAL_SIZE]; @@ -122,14 +122,14 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, output_result += in; #else output_result = fmax(output_result, in); -#if RETURN_REDICE +#ifdef RETURN_REDICE redice = in > output_result ? (int4)(i) : redice; #endif #endif } sum_mnn[local_id] = output_result; -#if RETURN_REDICE +#ifdef RETURN_REDICE rediceId[local_id] = redice; #endif barrier(CLK_LOCAL_MEM_FENCE); @@ -139,7 +139,7 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, sum_mnn[local_id] = sum_mnn[local_id] + sum_mnn[local_id + i]; #else { -#if RETURN_REDICE +#ifdef RETURN_REDICE rediceId[local_id] = sum_mnn[local_id] > sum_mnn[local_id + i] ? rediceId[local_id] : rediceId[local_id + i]; #endif sum_mnn[local_id] = fmax(sum_mnn[local_id], sum_mnn[local_id + i]); @@ -153,7 +153,7 @@ __kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, #endif WI_F(output, (int2)(output_channel_idx, output_batch_idx), output_result); - #if RETURN_REDICE + #ifdef RETURN_REDICE redice = rediceId[0]; WI_F(rediceOutput, (int2)(output_channel_idx, output_batch_idx), CONVERT_FLOAT4(redice)); #endif diff --git a/source/backend/opencl/execution/cl/pooling_buf.cl b/source/backend/opencl/execution/cl/pooling_buf.cl index 07db4f80..3e06c794 100644 --- a/source/backend/opencl/execution/cl/pooling_buf.cl +++ b/source/backend/opencl/execution/cl/pooling_buf.cl @@ -29,6 +29,9 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, const int iw_start = mad24(ow_idx, stride_shape.y, -pad_shape.y); const int ih_start = mad24(oh_idx, stride_shape.x, -pad_shape.x); + #ifdef RETURN_REDICE + int4 redice = (int4)0; + #endif #ifdef POOL_AVG COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(0); const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; @@ -57,9 +60,6 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, result = result / (COMPUTE_FLOAT4)(1.0*total_count); #else COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(-FLT_MAX); - #if RETURN_REDICE - int4 redice = (int4)0; - #endif const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; for(int kh=0; kh result ? (int4)((ih_start + kh) * input_shape.y + iw_start + kw) : redice; #endif result = fmax(result, inp_data); @@ -82,7 +82,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, const int out_offset = (((b_idx + c_idx*batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4; vstore4(CONVERT_FLOAT4(result), 0, output+out_offset); - #if RETURN_REDICE + #ifdef RETURN_REDICE vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+out_offset); #endif } @@ -105,10 +105,10 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, COMPUTE_FLOAT4 output_result = 0; #else COMPUTE_FLOAT4 output_result = (COMPUTE_FLOAT4)(-FLT_MAX); -#if RETURN_REDICE +#endif +#ifdef RETURN_REDICE int4 redice = (int4)0; int4 local rediceId[LOCAL_SIZE]; -#endif #endif COMPUTE_FLOAT4 local sum_mnn[LOCAL_SIZE]; @@ -122,14 +122,14 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, output_result += in; #else output_result = fmax(output_result, in); -#if RETURN_REDICE +#ifdef RETURN_REDICE redice = in > output_result ? (int4)(i) : redice; #endif #endif } sum_mnn[local_id] = output_result; -#if RETURN_REDICE +#ifdef RETURN_REDICE rediceId[local_id] = redice; #endif barrier(CLK_LOCAL_MEM_FENCE); @@ -139,7 +139,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, sum_mnn[local_id] = sum_mnn[local_id] + sum_mnn[local_id + i]; #else { -#if RETURN_REDICE +#ifdef RETURN_REDICE rediceId[local_id] = sum_mnn[local_id] > sum_mnn[local_id + i] ? rediceId[local_id] : rediceId[local_id + i]; #endif sum_mnn[local_id] = fmax(sum_mnn[local_id], sum_mnn[local_id + i]); @@ -154,7 +154,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, const int out_offset = (output_batch_idx + output_channel_idx*batch)*4; vstore4(CONVERT_FLOAT4(output_result), 0, output+out_offset); -#if RETURN_REDICE +#ifdef RETURN_REDICE redice = rediceId[0]; vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+out_offset); #endif diff --git a/source/backend/opencl/execution/cl/pooling_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/pooling_buf_mnn_cl.cpp index b3f7df23..f0581011 100644 --- a/source/backend/opencl/execution/cl/pooling_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/pooling_buf_mnn_cl.cpp @@ -27,6 +27,9 @@ const char* pooling_buf = " const int iw_start=mad24(ow_idx,stride_shape.y,-pad_shape.y);\n" " const int ih_start=mad24(oh_idx,stride_shape.x,-pad_shape.x);\n" " \n" +" #ifdef RETURN_REDICE\n" +" int4 redice=(int4)0;\n" +" #endif\n" " #ifdef POOL_AVG\n" " COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n" " const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n" @@ -55,9 +58,6 @@ const char* pooling_buf = " result=result/(COMPUTE_FLOAT4)(1.0*total_count);\n" " #else\n" " COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(-FLT_MAX);\n" -" #if RETURN_REDICE\n" -" int4 redice=(int4)0;\n" -" #endif\n" " const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4;\n" " for(int kh=0; khresult ? (int4)((ih_start+kh)*input_shape.y+iw_start+kw) : redice;\n" " #endif\n" " result=fmax(result,inp_data);\n" @@ -80,7 +80,7 @@ const char* pooling_buf = " \n" " const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n" " vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n" -" #if RETURN_REDICE\n" +" #ifdef RETURN_REDICE\n" " vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n" " #endif\n" "}\n" @@ -101,11 +101,11 @@ const char* pooling_buf = " COMPUTE_FLOAT4 output_result=0;\n" "#else\n" " COMPUTE_FLOAT4 output_result=(COMPUTE_FLOAT4)(-FLT_MAX);\n" -"#if RETURN_REDICE\n" +"#endif\n" +"#ifdef RETURN_REDICE\n" " int4 redice=(int4)0;\n" " int4 local rediceId[LOCAL_SIZE];\n" "#endif\n" -"#endif\n" " COMPUTE_FLOAT4 local sum_mnn[LOCAL_SIZE];\n" " const int inp_offset=((output_batch_idx+output_channel_idx*batch)*input_shape.x)*input_shape.y*4;\n" " const int size=input_shape.x*input_shape.y;\n" @@ -117,14 +117,14 @@ const char* pooling_buf = " output_result += in;\n" "#else\n" " output_result=fmax(output_result,in);\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " redice=in>output_result ? (int4)(i) : redice;\n" "#endif\n" "#endif\n" " }\n" " \n" " sum_mnn[local_id]=output_result;\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " rediceId[local_id]=redice;\n" "#endif\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" @@ -134,7 +134,7 @@ const char* pooling_buf = " sum_mnn[local_id]=sum_mnn[local_id]+sum_mnn[local_id+i];\n" "#else\n" " {\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " rediceId[local_id]=sum_mnn[local_id]>sum_mnn[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n" "#endif\n" " sum_mnn[local_id]=fmax(sum_mnn[local_id],sum_mnn[local_id+i]);\n" @@ -148,7 +148,7 @@ const char* pooling_buf = "#endif\n" " const int out_offset=(output_batch_idx+output_channel_idx*batch)*4;\n" " vstore4(CONVERT_FLOAT4(output_result),0,output+out_offset);\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " redice=rediceId[0];\n" " vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n" "#endif\n" diff --git a/source/backend/opencl/execution/cl/pooling_mnn_cl.cpp b/source/backend/opencl/execution/cl/pooling_mnn_cl.cpp index b0c76a88..3d865d29 100644 --- a/source/backend/opencl/execution/cl/pooling_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/pooling_mnn_cl.cpp @@ -24,6 +24,9 @@ const char* pooling = " const int input_height_start=mad24(output_height_idx,stride_shape.x,-pad_shape.x);\n" " const int input_width_start=mad24(output_width_idx,stride_shape.y,-pad_shape.y);\n" " const int input_channel_start=mul24(output_channel_idx,input_shape.y);\n" +" #ifdef RETURN_REDICE\n" +" int4 redice=(int4)0;\n" +" #endif\n" "#ifdef POOL_AVG\n" " FLOAT4 output_result=0;\n" " for (int height=0; heightoutput_result ? (int4)((input_height_start+height)*input_shape.y+input_width_start+width) : redice;\n" " #endif\n" " output_result=fmax(output_result,input_data);\n" @@ -76,11 +76,11 @@ const char* pooling = "#endif\n" " const int output_channel_width_idx=mad24(output_channel_idx,output_width,output_width_idx);\n" " WI_F(output,(int2)(output_channel_width_idx,output_batch_height_idx),output_result);\n" -" #if RETURN_REDICE\n" +" #ifdef RETURN_REDICE\n" " WI_F(rediceOutput,(int2)(output_channel_width_idx,output_batch_height_idx),CONVERT_FLOAT4(redice));\n" " #endif\n" "}\n" -"#ifdef LOCAL_SIZE\n" +"#if LOCAL_SIZE>1\n" "__kernel void global_pooling(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,\n" " __private const int2 input_shape,__private const int output_height,__private const int2 pad_shape,\n" " __private const int2 stride_shape,\n" @@ -94,11 +94,11 @@ const char* pooling = " FLOAT4 output_result=0;\n" "#else\n" " FLOAT4 output_result=(FLOAT4)(-FLT_MAX);\n" -"#if RETURN_REDICE\n" +"#endif\n" +"#ifdef RETURN_REDICE\n" " int4 redice=(int4)0;\n" " int4 local rediceId[LOCAL_SIZE];\n" "#endif\n" -"#endif\n" " FLOAT4 local sum_mnn[LOCAL_SIZE];\n" " int wc=output_channel_idx*input_shape.y;\n" " int bh=output_batch_idx*input_shape.x;\n" @@ -110,14 +110,14 @@ const char* pooling = " output_result += in;\n" "#else\n" " output_result=fmax(output_result,in);\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " redice=in>output_result ? (int4)(i) : redice;\n" "#endif\n" "#endif\n" " }\n" " \n" " sum_mnn[local_id]=output_result;\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " rediceId[local_id]=redice;\n" "#endif\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" @@ -127,7 +127,7 @@ const char* pooling = " sum_mnn[local_id]=sum_mnn[local_id]+sum_mnn[local_id+i];\n" "#else\n" " {\n" -"#if RETURN_REDICE\n" +"#ifdef RETURN_REDICE\n" " rediceId[local_id]=sum_mnn[local_id]>sum_mnn[local_id+i] ? rediceId[local_id] : rediceId[local_id+i];\n" "#endif\n" " sum_mnn[local_id]=fmax(sum_mnn[local_id],sum_mnn[local_id+i]);\n" @@ -140,7 +140,7 @@ const char* pooling = " output_result /= (input_shape.x*input_shape.y);\n" "#endif\n" " WI_F(output,(int2)(output_channel_idx,output_batch_idx),output_result);\n" -" #if RETURN_REDICE\n" +" #ifdef RETURN_REDICE\n" " redice=rediceId[0];\n" " WI_F(rediceOutput,(int2)(output_channel_idx,output_batch_idx),CONVERT_FLOAT4(redice));\n" " #endif\n" diff --git a/source/backend/opencl/execution/cl/raster_buf.cl b/source/backend/opencl/execution/cl/raster_buf.cl index 94791008..8a41313d 100644 --- a/source/backend/opencl/execution/cl/raster_buf.cl +++ b/source/backend/opencl/execution/cl/raster_buf.cl @@ -69,28 +69,30 @@ __kernel void raster_direct_buffer( int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2; int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2; + int inputIndexReal = 0; + int outputIndexReal = 0; #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW - int inputIndexReal = inputIndex; + inputIndexReal = inputIndex; #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC - int inputIndexReal = inputIndex; + inputIndexReal = inputIndex; #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 int in_w = inputIndex % src_width; inputIndex /= src_width; int in_h = inputIndex % src_height; inputIndex /= src_height; int in_c = inputIndex % src_channel; int in_b = inputIndex / src_channel; - int inputIndexReal = (((in_b + (in_c / 4) * src_batch) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4); + inputIndexReal = (((in_b + (in_c / 4) * src_batch) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4); #endif #if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW - int outputIndexReal = outputIndex; + outputIndexReal = outputIndex; #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC - int outputIndexReal = outputIndex; + outputIndexReal = outputIndex; #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 int out_w = outputIndex % dst_width; outputIndex /= dst_width; int out_h = outputIndex % dst_height; outputIndex /= dst_height; int out_c = outputIndex % dst_channel; int out_b = outputIndex / dst_channel; - int outputIndexReal = (((out_b + (out_c / 4) * dst_batch) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4); + outputIndexReal = (((out_b + (out_c / 4) * dst_batch) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4); #endif output[outputIndexReal] = (OUTPUT_TYPE)input[inputIndexReal]; } diff --git a/source/backend/opencl/execution/cl/raster_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/raster_buf_mnn_cl.cpp index c44e8974..557e2cc1 100644 --- a/source/backend/opencl/execution/cl/raster_buf_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/raster_buf_mnn_cl.cpp @@ -58,28 +58,30 @@ const char* raster_buf = " \n" " int inputIndex=inputOffset+id*combineSrcOffset+z*inputStride0+y*inputStride1+x*inputStride2;\n" " int outputIndex=outputOffset+id*combineDstOffset+z*outputStride0+y*outputStride1+x*outputStride2;\n" +" int inputIndexReal=0;\n" +" int outputIndexReal=0;\n" "#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" -" int inputIndexReal=inputIndex;\n" +" inputIndexReal=inputIndex;\n" "#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" -" int inputIndexReal=inputIndex;\n" +" inputIndexReal=inputIndex;\n" "#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" " int in_w=inputIndex % src_width; inputIndex /= src_width;\n" " int in_h=inputIndex % src_height; inputIndex /= src_height;\n" " int in_c=inputIndex % src_channel;\n" " int in_b=inputIndex/src_channel;\n" -" int inputIndexReal=(((in_b+(in_c/4)*src_batch)*src_height+in_h)*src_width+in_w)*4+(in_c % 4);\n" +" inputIndexReal=(((in_b+(in_c/4)*src_batch)*src_height+in_h)*src_width+in_w)*4+(in_c % 4);\n" "#endif\n" " \n" "#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" -" int outputIndexReal=outputIndex;\n" +" outputIndexReal=outputIndex;\n" "#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" -" int outputIndexReal=outputIndex;\n" +" outputIndexReal=outputIndex;\n" "#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" " int out_w=outputIndex % dst_width; outputIndex /= dst_width;\n" " int out_h=outputIndex % dst_height; outputIndex /= dst_height;\n" " int out_c=outputIndex % dst_channel;\n" " int out_b=outputIndex/dst_channel;\n" -" int outputIndexReal=(((out_b+(out_c/4)*dst_batch)*dst_height+out_h)*dst_width+out_w)*4+(out_c % 4);\n" +" outputIndexReal=(((out_b+(out_c/4)*dst_batch)*dst_height+out_h)*dst_width+out_w)*4+(out_c % 4);\n" "#endif\n" " output[outputIndexReal]=(OUTPUT_TYPE)input[inputIndexReal];\n" "}\n" diff --git a/source/backend/opencl/execution/cl/reduction.cl b/source/backend/opencl/execution/cl/reduction.cl index 11fb5568..93f9fe0f 100644 --- a/source/backend/opencl/execution/cl/reduction.cl +++ b/source/backend/opencl/execution/cl/reduction.cl @@ -18,7 +18,7 @@ __private const int global_size_dim0, __private const int global_size_dim1, __pr if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ return; \ } - + __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; diff --git a/source/backend/opencl/execution/cl/reduction_mnn_cl.cpp b/source/backend/opencl/execution/cl/reduction_mnn_cl.cpp index b89ce3e3..281989ab 100644 --- a/source/backend/opencl/execution/cl/reduction_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/reduction_mnn_cl.cpp @@ -12,6 +12,7 @@ const char* reduction = "#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n" "#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +" \n" "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" "__kernel void reduct_width(GLOBAL_SIZE_3_DIMS\n" " __read_only image2d_t input,\n" diff --git a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp index dc7f338a..27d9fd67 100644 --- a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp @@ -40,7 +40,7 @@ void ConvLowMemoryExecution::getInfoFromOpLowMemory(std::shared_ptrmOutputChannel; mResource->mBlockSize = totalCount / numAlpha; // set mDequantScale mDequantOffset - int numAlphaPack = ROUND_UP(numAlpha, 16); + int numAlphaPack = ROUND_UP(numAlpha, 4); int mapSize = mResource->mBlockSize * numAlphaPack * sizeof(int32_t) * 2; mResource->dequantScaleOffset.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, mapSize)); // transfer data from src in cpu to dst in gpu @@ -242,7 +242,7 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision()); + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; @@ -283,7 +283,7 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision()); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision()); uint32_t idx = 0; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); @@ -347,7 +347,7 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision()); + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[knl_idx], buildOption, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; @@ -391,7 +391,7 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision()); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_int", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision()); uint32_t idx = 0; cl_int ret = CL_SUCCESS; @@ -445,7 +445,7 @@ void ConvLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * output) if(inputChannels % 4 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm", kernelname, buildOption, mOpenCLBackend->getPrecision()); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_int", kernelname, buildOption, mOpenCLBackend->getPrecision()); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); mGlobalWorkSize = {static_cast(global_x), static_cast(global_y)}; // MNN_PRINT("Kernel is %d.\n", min_index); @@ -516,13 +516,7 @@ ConvLowMemoryExecution::ConvLowMemoryExecution(const std::vector &inpu } else if (conv2dCommonParams->relu6()) { mResource->mBuildOptions.emplace("-DRELU6"); } - if (mNumQuantBit == 8) { - // int8 case - mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); - } else if (mNumQuantBit == 4){ - // int4 case - mResource->mBuildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); - } else {/* More types to be supported. */} + mResource->mBuildOptions.emplace("-DQUANT_BIT=" + std::to_string(mNumQuantBit)); #ifdef LOG_VERBOSE MNN_PRINT("end ConvExecution init !\n"); #endif diff --git a/source/backend/opencl/execution/image/LoopExecution.cpp b/source/backend/opencl/execution/image/LoopExecution.cpp index a4897387..be698599 100644 --- a/source/backend/opencl/execution/image/LoopExecution.cpp +++ b/source/backend/opencl/execution/image/LoopExecution.cpp @@ -519,7 +519,7 @@ LoopBinaryExecution::LoopBinaryExecution(const LoopParam *loop, const std::strin : CommonExecution(bn, op) { mLoop = loop; mTensors.resize(mLoop->tensorNumber()); - mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute); + mBuildOptions.emplace("-DOPERATOR=" + compute); } ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector &inputs, const std::vector &outputs) { @@ -569,7 +569,6 @@ ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector &input { Unit unit; std::set buildOptions = mBuildOptions; - buildOptions.emplace("-DCOMPUTE_CUMSUM"); unit.kernel = runTime->buildKernel("loop", "loop_cumsum", buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); diff --git a/source/backend/opencl/execution/image/PoolExecution.cpp b/source/backend/opencl/execution/image/PoolExecution.cpp index 9c4cc929..af39fe4c 100644 --- a/source/backend/opencl/execution/image/PoolExecution.cpp +++ b/source/backend/opencl/execution/image/PoolExecution.cpp @@ -80,7 +80,7 @@ ErrorCode PoolExecution::onEncode(const std::vector &inputs, const std std::set buildOptions; std::string kernelName = "pooling"; auto runtime = mOpenCLBackend->getOpenCLRuntime(); - int local_size; + int local_size = 1; if (mPoolParams->isGlobal()) { std::vector inputShape = tensorShapeFormat(inputs[0]); @@ -90,8 +90,8 @@ ErrorCode PoolExecution::onEncode(const std::vector &inputs, const std kernelName = "global_pooling"; auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize); local_size = getLocalSize(inputShape.at(1) * inputShape.at(2), MaxLocalSize); - buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); } + buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); if (mPadType == PoolPadType_SAME) { int padNeededHeight = std::max(0, (output->height() - 1) * mStrides[0] + mKernels[0] - input->height()); diff --git a/source/backend/opencl/schema/CLCache.fbs b/source/backend/opencl/schema/CLCache.fbs index a6fd6c3f..8c507709 100644 --- a/source/backend/opencl/schema/CLCache.fbs +++ b/source/backend/opencl/schema/CLCache.fbs @@ -29,18 +29,12 @@ table GemmInfo { paramInfo:[uint]; } -table PreParamInfo{ - preParamName:string; - preParamData:uint; -} - table BackendInfo{ mnnVersion:string; deviceName:string; programs:[Shader]; tunings:[Autotuning]; gemm:[GemmInfo]; - preParam:[PreParamInfo]; } table Cache { diff --git a/source/backend/opencl/schema/CLCache/Autotuning.py b/source/backend/opencl/schema/CLCache/Autotuning.py new file mode 100644 index 00000000..aca29e36 --- /dev/null +++ b/source/backend/opencl/schema/CLCache/Autotuning.py @@ -0,0 +1,141 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Autotuning(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Autotuning() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAutotuning(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # Autotuning + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Autotuning + def Key(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Autotuning + def GloablSize(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Autotuning + def GloablSizeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # Autotuning + def GloablSizeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Autotuning + def GloablSizeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # Autotuning + def LocalSize(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Autotuning + def LocalSizeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # Autotuning + def LocalSizeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Autotuning + def LocalSizeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # Autotuning + def TimeCost(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + +def AutotuningStart(builder): + builder.StartObject(4) + +def Start(builder): + AutotuningStart(builder) + +def AutotuningAddKey(builder, key): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0) + +def AddKey(builder, key): + AutotuningAddKey(builder, key) + +def AutotuningAddGloablSize(builder, gloablSize): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(gloablSize), 0) + +def AddGloablSize(builder, gloablSize): + AutotuningAddGloablSize(builder, gloablSize) + +def AutotuningStartGloablSizeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartGloablSizeVector(builder, numElems): + return AutotuningStartGloablSizeVector(builder, numElems) + +def AutotuningAddLocalSize(builder, localSize): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(localSize), 0) + +def AddLocalSize(builder, localSize): + AutotuningAddLocalSize(builder, localSize) + +def AutotuningStartLocalSizeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartLocalSizeVector(builder, numElems): + return AutotuningStartLocalSizeVector(builder, numElems) + +def AutotuningAddTimeCost(builder, timeCost): + builder.PrependUint32Slot(3, timeCost, 0) + +def AddTimeCost(builder, timeCost): + AutotuningAddTimeCost(builder, timeCost) + +def AutotuningEnd(builder): + return builder.EndObject() + +def End(builder): + return AutotuningEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/BackendInfo.py b/source/backend/opencl/schema/CLCache/BackendInfo.py new file mode 100644 index 00000000..a039b0f8 --- /dev/null +++ b/source/backend/opencl/schema/CLCache/BackendInfo.py @@ -0,0 +1,174 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class BackendInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BackendInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBackendInfo(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # BackendInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BackendInfo + def MnnVersion(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # BackendInfo + def DeviceName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # BackendInfo + def Programs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.Shader import Shader + obj = Shader() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # BackendInfo + def ProgramsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # BackendInfo + def ProgramsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # BackendInfo + def Tunings(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.Autotuning import Autotuning + obj = Autotuning() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # BackendInfo + def TuningsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # BackendInfo + def TuningsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # BackendInfo + def Gemm(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.GemmInfo import GemmInfo + obj = GemmInfo() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # BackendInfo + def GemmLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # BackendInfo + def GemmIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + +def BackendInfoStart(builder): + builder.StartObject(5) + +def Start(builder): + BackendInfoStart(builder) + +def BackendInfoAddMnnVersion(builder, mnnVersion): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(mnnVersion), 0) + +def AddMnnVersion(builder, mnnVersion): + BackendInfoAddMnnVersion(builder, mnnVersion) + +def BackendInfoAddDeviceName(builder, deviceName): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(deviceName), 0) + +def AddDeviceName(builder, deviceName): + BackendInfoAddDeviceName(builder, deviceName) + +def BackendInfoAddPrograms(builder, programs): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(programs), 0) + +def AddPrograms(builder, programs): + BackendInfoAddPrograms(builder, programs) + +def BackendInfoStartProgramsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartProgramsVector(builder, numElems): + return BackendInfoStartProgramsVector(builder, numElems) + +def BackendInfoAddTunings(builder, tunings): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(tunings), 0) + +def AddTunings(builder, tunings): + BackendInfoAddTunings(builder, tunings) + +def BackendInfoStartTuningsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartTuningsVector(builder, numElems): + return BackendInfoStartTuningsVector(builder, numElems) + +def BackendInfoAddGemm(builder, gemm): + builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(gemm), 0) + +def AddGemm(builder, gemm): + BackendInfoAddGemm(builder, gemm) + +def BackendInfoStartGemmVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartGemmVector(builder, numElems): + return BackendInfoStartGemmVector(builder, numElems) + +def BackendInfoEnd(builder): + return builder.EndObject() + +def End(builder): + return BackendInfoEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/Cache.py b/source/backend/opencl/schema/CLCache/Cache.py new file mode 100644 index 00000000..1f314ebd --- /dev/null +++ b/source/backend/opencl/schema/CLCache/Cache.py @@ -0,0 +1,111 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Cache(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Cache() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCache(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # Cache + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Cache + def Backends(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.BackendInfo import BackendInfo + obj = BackendInfo() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Cache + def BackendsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Cache + def BackendsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Cache + def Tuned(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.OpInfo import OpInfo + obj = OpInfo() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Cache + def TunedLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Cache + def TunedIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def CacheStart(builder): + builder.StartObject(2) + +def Start(builder): + CacheStart(builder) + +def CacheAddBackends(builder, backends): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(backends), 0) + +def AddBackends(builder, backends): + CacheAddBackends(builder, backends) + +def CacheStartBackendsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartBackendsVector(builder, numElems): + return CacheStartBackendsVector(builder, numElems) + +def CacheAddTuned(builder, tuned): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(tuned), 0) + +def AddTuned(builder, tuned): + CacheAddTuned(builder, tuned) + +def CacheStartTunedVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartTunedVector(builder, numElems): + return CacheStartTunedVector(builder, numElems) + +def CacheEnd(builder): + return builder.EndObject() + +def End(builder): + return CacheEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/GemmInfo.py b/source/backend/opencl/schema/CLCache/GemmInfo.py new file mode 100644 index 00000000..28df382a --- /dev/null +++ b/source/backend/opencl/schema/CLCache/GemmInfo.py @@ -0,0 +1,115 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class GemmInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = GemmInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsGemmInfo(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # GemmInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # GemmInfo + def GemmSize(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # GemmInfo + def GemmSizeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # GemmInfo + def GemmSizeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # GemmInfo + def GemmSizeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # GemmInfo + def ParamInfo(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # GemmInfo + def ParamInfoAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # GemmInfo + def ParamInfoLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # GemmInfo + def ParamInfoIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def GemmInfoStart(builder): + builder.StartObject(2) + +def Start(builder): + GemmInfoStart(builder) + +def GemmInfoAddGemmSize(builder, gemmSize): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(gemmSize), 0) + +def AddGemmSize(builder, gemmSize): + GemmInfoAddGemmSize(builder, gemmSize) + +def GemmInfoStartGemmSizeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartGemmSizeVector(builder, numElems): + return GemmInfoStartGemmSizeVector(builder, numElems) + +def GemmInfoAddParamInfo(builder, paramInfo): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(paramInfo), 0) + +def AddParamInfo(builder, paramInfo): + GemmInfoAddParamInfo(builder, paramInfo) + +def GemmInfoStartParamInfoVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartParamInfoVector(builder, numElems): + return GemmInfoStartParamInfoVector(builder, numElems) + +def GemmInfoEnd(builder): + return builder.EndObject() + +def End(builder): + return GemmInfoEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/OpInfo.py b/source/backend/opencl/schema/CLCache/OpInfo.py new file mode 100644 index 00000000..f5282770 --- /dev/null +++ b/source/backend/opencl/schema/CLCache/OpInfo.py @@ -0,0 +1,137 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class OpInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = OpInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsOpInfo(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # OpInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # OpInfo + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # OpInfo + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # OpInfo + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.TensorInfo import TensorInfo + obj = TensorInfo() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # OpInfo + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # OpInfo + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # OpInfo + def Outputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from CLCache.TensorInfo import TensorInfo + obj = TensorInfo() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # OpInfo + def OutputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # OpInfo + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + +def OpInfoStart(builder): + builder.StartObject(4) + +def Start(builder): + OpInfoStart(builder) + +def OpInfoAddName(builder, name): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + +def AddName(builder, name): + OpInfoAddName(builder, name) + +def OpInfoAddType(builder, type): + builder.PrependInt32Slot(1, type, 0) + +def AddType(builder, type): + OpInfoAddType(builder, type) + +def OpInfoAddInputs(builder, inputs): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) + +def AddInputs(builder, inputs): + OpInfoAddInputs(builder, inputs) + +def OpInfoStartInputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartInputsVector(builder, numElems): + return OpInfoStartInputsVector(builder, numElems) + +def OpInfoAddOutputs(builder, outputs): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) + +def AddOutputs(builder, outputs): + OpInfoAddOutputs(builder, outputs) + +def OpInfoStartOutputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartOutputsVector(builder, numElems): + return OpInfoStartOutputsVector(builder, numElems) + +def OpInfoEnd(builder): + return builder.EndObject() + +def End(builder): + return OpInfoEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/Shader.py b/source/backend/opencl/schema/CLCache/Shader.py new file mode 100644 index 00000000..b7778a0e --- /dev/null +++ b/source/backend/opencl/schema/CLCache/Shader.py @@ -0,0 +1,115 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Shader(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Shader() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsShader(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # Shader + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Shader + def Buffer(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # Shader + def BufferAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # Shader + def BufferLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Shader + def BufferIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Shader + def Program(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Shader + def Kernel(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Shader + def BuildInfo(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def ShaderStart(builder): + builder.StartObject(4) + +def Start(builder): + ShaderStart(builder) + +def ShaderAddBuffer(builder, buffer): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(buffer), 0) + +def AddBuffer(builder, buffer): + ShaderAddBuffer(builder, buffer) + +def ShaderStartBufferVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + +def StartBufferVector(builder, numElems): + return ShaderStartBufferVector(builder, numElems) + +def ShaderAddProgram(builder, program): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(program), 0) + +def AddProgram(builder, program): + ShaderAddProgram(builder, program) + +def ShaderAddKernel(builder, kernel): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0) + +def AddKernel(builder, kernel): + ShaderAddKernel(builder, kernel) + +def ShaderAddBuildInfo(builder, buildInfo): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(buildInfo), 0) + +def AddBuildInfo(builder, buildInfo): + ShaderAddBuildInfo(builder, buildInfo) + +def ShaderEnd(builder): + return builder.EndObject() + +def End(builder): + return ShaderEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/TensorInfo.py b/source/backend/opencl/schema/CLCache/TensorInfo.py new file mode 100644 index 00000000..a1dbc825 --- /dev/null +++ b/source/backend/opencl/schema/CLCache/TensorInfo.py @@ -0,0 +1,76 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: CLCache + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TensorInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TensorInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTensorInfo(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # TensorInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TensorInfo + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TensorInfo + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TensorInfo + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TensorInfo + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def TensorInfoStart(builder): + builder.StartObject(1) + +def Start(builder): + TensorInfoStart(builder) + +def TensorInfoAddShape(builder, shape): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) + +def AddShape(builder, shape): + TensorInfoAddShape(builder, shape) + +def TensorInfoStartShapeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def StartShapeVector(builder, numElems): + return TensorInfoStartShapeVector(builder, numElems) + +def TensorInfoEnd(builder): + return builder.EndObject() + +def End(builder): + return TensorInfoEnd(builder) diff --git a/source/backend/opencl/schema/CLCache/__init__.py b/source/backend/opencl/schema/CLCache/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source/backend/opencl/schema/current/CLCache_generated.h b/source/backend/opencl/schema/current/CLCache_generated.h index f172e5d1..5be4217b 100644 --- a/source/backend/opencl/schema/current/CLCache_generated.h +++ b/source/backend/opencl/schema/current/CLCache_generated.h @@ -23,9 +23,6 @@ struct AutotuningT; struct GemmInfo; struct GemmInfoT; -struct PreParamInfo; -struct PreParamInfoT; - struct BackendInfo; struct BackendInfoT; @@ -42,8 +39,6 @@ inline const flatbuffers::TypeTable *AutotuningTypeTable(); inline const flatbuffers::TypeTable *GemmInfoTypeTable(); -inline const flatbuffers::TypeTable *PreParamInfoTypeTable(); - inline const flatbuffers::TypeTable *BackendInfoTypeTable(); inline const flatbuffers::TypeTable *CacheTypeTable(); @@ -430,71 +425,6 @@ inline flatbuffers::Offset CreateGemmInfo( flatbuffers::Offset CreateGemmInfo(flatbuffers::FlatBufferBuilder &_fbb, const GemmInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct PreParamInfoT : public flatbuffers::NativeTable { - typedef PreParamInfo TableType; - std::string preParamName; - uint32_t preParamData; - PreParamInfoT() - : preParamData(0) { - } -}; - -struct PreParamInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef PreParamInfoT NativeTableType; - static const flatbuffers::TypeTable *MiniReflectTypeTable() { - return PreParamInfoTypeTable(); - } - const flatbuffers::String *preParamName() const { - return GetPointer(4); - } - uint32_t preParamData() const { - return GetField(6, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, 4) && - verifier.VerifyString(preParamName()) && - VerifyField(verifier, 6) && - verifier.EndTable(); - } - PreParamInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(PreParamInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -}; - -struct PreParamInfoBuilder { - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_preParamName(flatbuffers::Offset preParamName) { - fbb_.AddOffset(4, preParamName); - } - void add_preParamData(uint32_t preParamData) { - fbb_.AddElement(6, preParamData, 0); - } - explicit PreParamInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - PreParamInfoBuilder &operator=(const PreParamInfoBuilder &); - flatbuffers::Offset Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); - return o; - } -}; - -inline flatbuffers::Offset CreatePreParamInfo( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset preParamName = 0, - uint32_t preParamData = 0) { - PreParamInfoBuilder builder_(_fbb); - builder_.add_preParamData(preParamData); - builder_.add_preParamName(preParamName); - return builder_.Finish(); -} - -flatbuffers::Offset CreatePreParamInfo(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); - struct BackendInfoT : public flatbuffers::NativeTable { typedef BackendInfo TableType; std::string mnnVersion; @@ -502,7 +432,6 @@ struct BackendInfoT : public flatbuffers::NativeTable { std::vector> programs; std::vector> tunings; std::vector> gemm; - std::vector> preParam; BackendInfoT() { } }; @@ -527,9 +456,6 @@ struct BackendInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *gemm() const { return GetPointer> *>(12); } - const flatbuffers::Vector> *preParam() const { - return GetPointer> *>(14); - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, 4) && @@ -545,9 +471,6 @@ struct BackendInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, 12) && verifier.VerifyVector(gemm()) && verifier.VerifyVectorOfTables(gemm()) && - VerifyOffset(verifier, 14) && - verifier.VerifyVector(preParam()) && - verifier.VerifyVectorOfTables(preParam()) && verifier.EndTable(); } BackendInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -573,9 +496,6 @@ struct BackendInfoBuilder { void add_gemm(flatbuffers::Offset>> gemm) { fbb_.AddOffset(12, gemm); } - void add_preParam(flatbuffers::Offset>> preParam) { - fbb_.AddOffset(14, preParam); - } explicit BackendInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -594,10 +514,8 @@ inline flatbuffers::Offset CreateBackendInfo( flatbuffers::Offset deviceName = 0, flatbuffers::Offset>> programs = 0, flatbuffers::Offset>> tunings = 0, - flatbuffers::Offset>> gemm = 0, - flatbuffers::Offset>> preParam = 0) { + flatbuffers::Offset>> gemm = 0) { BackendInfoBuilder builder_(_fbb); - builder_.add_preParam(preParam); builder_.add_gemm(gemm); builder_.add_tunings(tunings); builder_.add_programs(programs); @@ -835,35 +753,6 @@ inline flatbuffers::Offset CreateGemmInfo(flatbuffers::FlatBufferBuild _paramInfo); } -inline PreParamInfoT *PreParamInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new PreParamInfoT(); - UnPackTo(_o, _resolver); - return _o; -} - -inline void PreParamInfo::UnPackTo(PreParamInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const { - (void)_o; - (void)_resolver; - { auto _e = preParamName(); if (_e) _o->preParamName = _e->str(); }; - { auto _e = preParamData(); _o->preParamData = _e; }; -} - -inline flatbuffers::Offset PreParamInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreatePreParamInfo(_fbb, _o, _rehasher); -} - -inline flatbuffers::Offset CreatePreParamInfo(flatbuffers::FlatBufferBuilder &_fbb, const PreParamInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) { - (void)_rehasher; - (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PreParamInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _preParamName = _o->preParamName.empty() ? 0 : _fbb.CreateString(_o->preParamName); - auto _preParamData = _o->preParamData; - return CLCache::CreatePreParamInfo( - _fbb, - _preParamName, - _preParamData); -} - inline BackendInfoT *BackendInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BackendInfoT(); UnPackTo(_o, _resolver); @@ -878,7 +767,6 @@ inline void BackendInfo::UnPackTo(BackendInfoT *_o, const flatbuffers::resolver_ { auto _e = programs(); if (_e) { _o->programs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->programs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; { auto _e = tunings(); if (_e) { _o->tunings.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tunings[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; { auto _e = gemm(); if (_e) { _o->gemm.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->gemm[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; - { auto _e = preParam(); if (_e) { _o->preParam.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->preParam[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; } inline flatbuffers::Offset BackendInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BackendInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -894,15 +782,13 @@ inline flatbuffers::Offset CreateBackendInfo(flatbuffers::FlatBuffe auto _programs = _o->programs.size() ? _fbb.CreateVector> (_o->programs.size(), [](size_t i, _VectorArgs *__va) { return CreateShader(*__va->__fbb, __va->__o->programs[i].get(), __va->__rehasher); }, &_va ) : 0; auto _tunings = _o->tunings.size() ? _fbb.CreateVector> (_o->tunings.size(), [](size_t i, _VectorArgs *__va) { return CreateAutotuning(*__va->__fbb, __va->__o->tunings[i].get(), __va->__rehasher); }, &_va ) : 0; auto _gemm = _o->gemm.size() ? _fbb.CreateVector> (_o->gemm.size(), [](size_t i, _VectorArgs *__va) { return CreateGemmInfo(*__va->__fbb, __va->__o->gemm[i].get(), __va->__rehasher); }, &_va ) : 0; - auto _preParam = _o->preParam.size() ? _fbb.CreateVector> (_o->preParam.size(), [](size_t i, _VectorArgs *__va) { return CreatePreParamInfo(*__va->__fbb, __va->__o->preParam[i].get(), __va->__rehasher); }, &_va ) : 0; return CLCache::CreateBackendInfo( _fbb, _mnnVersion, _deviceName, _programs, _tunings, - _gemm, - _preParam); + _gemm); } inline CacheT *Cache::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -1022,46 +908,28 @@ inline const flatbuffers::TypeTable *GemmInfoTypeTable() { return &tt; } -inline const flatbuffers::TypeTable *PreParamInfoTypeTable() { - static const flatbuffers::TypeCode type_codes[] = { - { flatbuffers::ET_STRING, 0, -1 }, - { flatbuffers::ET_UINT, 0, -1 } - }; - static const char * const names[] = { - "preParamName", - "preParamData" - }; - static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 2, type_codes, nullptr, nullptr, names - }; - return &tt; -} - inline const flatbuffers::TypeTable *BackendInfoTypeTable() { static const flatbuffers::TypeCode type_codes[] = { { flatbuffers::ET_STRING, 0, -1 }, { flatbuffers::ET_STRING, 0, -1 }, { flatbuffers::ET_SEQUENCE, 1, 0 }, { flatbuffers::ET_SEQUENCE, 1, 1 }, - { flatbuffers::ET_SEQUENCE, 1, 2 }, - { flatbuffers::ET_SEQUENCE, 1, 3 } + { flatbuffers::ET_SEQUENCE, 1, 2 } }; static const flatbuffers::TypeFunction type_refs[] = { ShaderTypeTable, AutotuningTypeTable, - GemmInfoTypeTable, - PreParamInfoTypeTable + GemmInfoTypeTable }; static const char * const names[] = { "mnnVersion", "deviceName", "programs", "tunings", - "gemm", - "preParam" + "gemm" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 6, type_codes, type_refs, nullptr, names + flatbuffers::ST_TABLE, 5, type_codes, type_refs, nullptr, names }; return &tt; } diff --git a/source/backend/opencl/schema/import_cache.py b/source/backend/opencl/schema/import_cache.py new file mode 100644 index 00000000..ee367934 --- /dev/null +++ b/source/backend/opencl/schema/import_cache.py @@ -0,0 +1,80 @@ +import flatbuffers +from CLCache.Cache import Cache +from CLCache.BackendInfo import BackendInfo + +def generate_cpp_header(buffer): + cache = Cache.GetRootAs(buffer, 0) + backends_len = cache.BackendsLength() + + opencl_tune_map = {} + + # 读取 BackendInfo 信息 + for i in range(backends_len): + backend_info = cache.Backends(i) + device_name = backend_info.DeviceName().decode("utf-8") + + autotuning_map = {} + tunings_len = backend_info.TuningsLength() + + for j in range(tunings_len): + tuning = backend_info.Tunings(j) + key = tuning.Key().decode("utf-8") + global_size = list(tuning.GloablSize(j) for j in range(tuning.GloablSizeLength())) + local_size =list(tuning.LocalSize(j) for j in range(tuning.LocalSizeLength())) + + if key not in autotuning_map: + autotuning_map[key] = [] + + autotuning_map[key].append((global_size, local_size)) + + gemm_len = backend_info.GemmLength() + + for j in range(gemm_len): + gemm = backend_info.Gemm(j) + key = 'Xgemm_tune' + gemm_size = list(gemm.GemmSize(j) for j in range(gemm.GemmSizeLength())) + param_info =list(gemm.ParamInfo(j) for j in range(gemm.ParamInfoLength())) + + if key not in autotuning_map: + autotuning_map[key] = [] + + autotuning_map[key].append((gemm_size, param_info)) + + opencl_tune_map[device_name] = autotuning_map + + # 生成 C++ 代码字符串 + cpp_code = "const std::map, std::vector>>>> OpenCLTuneMap = {\n" + + for device, tuning_map in opencl_tune_map.items(): + cpp_code += f' {{"{device}", {{\n' + for key, size_pairs in tuning_map.items(): + cpp_code += f' {{"{key}", {{\n' + for sizes in size_pairs: + cpp_code += f' {{{{ {", ".join(map(str, sizes[0]))} }}, {{ {", ".join(map(str, sizes[1]))} }}}},\n' + cpp_code += " }},\n" + cpp_code += " }},\n" + + cpp_code += "};\n" + + return cpp_code + + +if __name__ == '__main__': + import sys + if len(sys.argv) != 2: + print("Usage: python import_cache.py ") + print("Example: python merge_cache.py mnn_cachefile.bin") + sys.exit(1) + + with open(sys.argv[1], 'rb') as f: + buffer = f.read() + + cpp_header_code = generate_cpp_header(buffer) + + # 将结果保存到头文件中 + with open('OpenCLTuneMap.hpp', 'w') as header_file: + header_file.write("#include \n#include \n#include \n\nnamespace MNN { \n") + header_file.write(cpp_header_code) + header_file.write("\n}\n") + + print("C++ header file generated.") \ No newline at end of file diff --git a/source/backend/opencl/schema/merge_cache.py b/source/backend/opencl/schema/merge_cache.py new file mode 100644 index 00000000..f02c3e76 --- /dev/null +++ b/source/backend/opencl/schema/merge_cache.py @@ -0,0 +1,160 @@ +import flatbuffers +from CLCache import Cache, BackendInfo, Autotuning, GemmInfo + +def load_backend_infos(file_path): + with open(file_path, 'rb') as f: + buf = bytearray(f.read()) + cache = Cache.Cache.GetRootAs(buf, 0) + backends = [] + for i in range(cache.BackendsLength()): + backend = cache.Backends(i) + backends.append(backend) + return backends + +def load_tune_infos(backends): + original_map = {} + for backend in backends: + mnn_version = backend.MnnVersion() + device_name = backend.DeviceName() + tunings = {} + for i in range(backend.TuningsLength()): + tune = backend.Tunings(i) + key = tune.Key() + global_size = [tune.GloablSize(j) for j in range(tune.GloablSizeLength())] + local_size = [tune.LocalSize(j) for j in range(tune.LocalSizeLength())] + cost_time = tune.TimeCost() + tunings[(key, tuple(global_size))] = (local_size, cost_time) + + #gemm tune info + for i in range(backend.GemmLength()): + tune = backend.Gemm(i) + key = 'Xgemm_tune' + gemm_size = [tune.GemmSize(j) for j in range(tune.GemmSizeLength())] + param_info = [tune.ParamInfo(j) for j in range(tune.ParamInfoLength())] + tunings[(key, tuple(gemm_size))] = (param_info, 0) + original_map[(mnn_version, device_name)] = tunings + return original_map + +def create_backend_info(new_backends, original_backends): + + original_map = load_tune_infos(original_backends) + new_map = load_tune_infos(new_backends) + + for ver_dev in new_map: + if ver_dev in original_map: + new_tune = new_map[ver_dev] + original_tune = original_map[ver_dev] + for key in new_tune: + if key not in original_tune: + original_tune[key] = new_tune[key] + else: + original_map[ver_dev] = new_map[ver_dev] + + return original_map + +def build_cache(nested_dict): + """将嵌套字典转换为 FlatBuffers 的 Cache 结构""" + builder = flatbuffers.Builder() + + # ====================== 构建 BackendInfo 列表 ====================== + backend_offsets = [] + for (mnn_ver, device_name), autotune_dict in nested_dict.items(): + # 构建字符串 + mnn_ver_offset = builder.CreateString(mnn_ver) + device_name_offset = builder.CreateString(device_name) + + # 构建 Autotuning 条目 + tuning_offsets = [] + gemm_offsets = [] + for (key, global_size), (local_size, time_cost) in autotune_dict.items(): + if key == 'Xgemm_tune': + # 构建 GemmSize 向量 (倒序填充) + GemmInfo.GemmInfoStartGemmSizeVector(builder, len(global_size)) + for n in reversed(global_size): + builder.PrependUint32(n) + global_size_offset = builder.EndVector() + + # 构建 ParamInfo 向量 (倒序填充) + GemmInfo.GemmInfoStartParamInfoVector(builder, len(local_size)) + for n in reversed(local_size): + builder.PrependUint32(n) + local_size_offset = builder.EndVector() + + # 构建 Autotuning 对象 + GemmInfo.GemmInfoStart(builder) + GemmInfo.GemmInfoAddGemmSize(builder, global_size_offset) + GemmInfo.GemmInfoAddParamInfo(builder, local_size_offset) + gemm_offsets.append(GemmInfo.GemmInfoEnd(builder)) + else: + # 构建字符串 + key_offset = builder.CreateString(key) + + # 构建 globalSize 向量 (倒序填充) + Autotuning.AutotuningStartGloablSizeVector(builder, len(global_size)) + for n in reversed(global_size): + builder.PrependUint32(n) + global_size_offset = builder.EndVector() + + # 构建 localSize 向量 (倒序填充) + Autotuning.AutotuningStartLocalSizeVector(builder, len(local_size)) + for n in reversed(local_size): + builder.PrependUint32(n) + local_size_offset = builder.EndVector() + + # 构建 Autotuning 对象 + Autotuning.AutotuningStart(builder) + Autotuning.AutotuningAddKey(builder, key_offset) + Autotuning.AutotuningAddGloablSize(builder, global_size_offset) + Autotuning.AutotuningAddLocalSize(builder, local_size_offset) + Autotuning.AutotuningAddTimeCost(builder, time_cost) + Autotuning.AutotuningAddTimeCost(builder, 0) + tuning_offsets.append(Autotuning.AutotuningEnd(builder)) + + # 构建 tunings 向量 + BackendInfo.BackendInfoStartTuningsVector(builder, len(tuning_offsets)) + for offset in reversed(tuning_offsets): + builder.PrependUOffsetTRelative(offset) + tunings_offset = builder.EndVector() + + # 构建 gemm 向量 + BackendInfo.BackendInfoStartGemmVector(builder, len(gemm_offsets)) + for offset in reversed(gemm_offsets): + builder.PrependUOffsetTRelative(offset) + gemm_offsets = builder.EndVector() + + # 构建 BackendInfo + BackendInfo.BackendInfoStart(builder) + BackendInfo.BackendInfoAddMnnVersion(builder, mnn_ver_offset) + BackendInfo.BackendInfoAddDeviceName(builder, device_name_offset) + BackendInfo.BackendInfoAddTunings(builder, tunings_offset) + BackendInfo.BackendInfoAddGemm(builder, gemm_offsets) + backend_offsets.append(BackendInfo.BackendInfoEnd(builder)) + + # ====================== 构建最终 Cache ====================== + # 构建 backends 向量 + Cache.CacheStartBackendsVector(builder, len(backend_offsets)) + for offset in reversed(backend_offsets): + builder.PrependUOffsetTRelative(offset) + backends_offset = builder.EndVector() + + # 构建根对象 + Cache.CacheStart(builder) + Cache.CacheAddBackends(builder, backends_offset) + cache = Cache.CacheEnd(builder) + + builder.Finish(cache) + return builder.Output() + +if __name__ == '__main__': + import sys + if len(sys.argv) != 4: + print("Usage: python merge_cache.py ") + print("Example: python merge_cache.py mnn_cachefile.bin mnn_cachefile_total.bin new_cache.bin") + sys.exit(1) + original_backends = load_backend_infos(sys.argv[1]) + new_backends = load_backend_infos(sys.argv[2]) + original_map = create_backend_info(new_backends, original_backends) + #print(original_map) + binary_data = build_cache(original_map) + with open(sys.argv[3], "wb") as f: + f.write(binary_data) \ No newline at end of file diff --git a/source/shape/ShapeDepthToSpace.cpp b/source/shape/ShapeDepthToSpace.cpp index eca65cfe..db94b5d6 100644 --- a/source/shape/ShapeDepthToSpace.cpp +++ b/source/shape/ShapeDepthToSpace.cpp @@ -19,8 +19,6 @@ class DepthToSpaceSizeComputer : public SizeComputer { MNN_ASSERT(outputs.size() == 1); MNN_ASSERT(inputs[0]->buffer().dimensions == 4); - // here only implement NHWC - // TODO: implement NC4HW4 const int blockSize = op->main_as_DepthSpaceParam()->blockSize(); MNN_ASSERT(blockSize >= 1); auto format = TensorUtils::getDescribe(inputs[0])->dimensionFormat; diff --git a/source/shape/ShapeSpaceToDepth.cpp b/source/shape/ShapeSpaceToDepth.cpp index a8a78edb..94360b37 100644 --- a/source/shape/ShapeSpaceToDepth.cpp +++ b/source/shape/ShapeSpaceToDepth.cpp @@ -19,10 +19,8 @@ class SpaceToDepthSizeComputer : public SizeComputer { MNN_ASSERT(outputs.size() == 1); MNN_ASSERT(inputs[0]->buffer().dimensions == 4); - // here only implement NHWC - // TODO: implement NC4HW4 const int blockSize = op->main_as_DepthSpaceParam()->blockSize(); - MNN_ASSERT(blockSize > 1); + MNN_ASSERT(blockSize >= 1); auto& ib = inputs[0]->buffer(); auto& ob = outputs[0]->buffer(); diff --git a/tools/cpp/CMakeLists.txt b/tools/cpp/CMakeLists.txt index c10d56db..31df54be 100644 --- a/tools/cpp/CMakeLists.txt +++ b/tools/cpp/CMakeLists.txt @@ -17,6 +17,10 @@ IF(CMAKE_SYSTEM_NAME MATCHES "^Android") target_link_libraries(GpuInterTest.out android) list(APPEND MNN_CPP_TOOLS GpuInterTest.out) ENDIF() + +add_executable(OpenCLProgramBuildTest.out ${CMAKE_CURRENT_LIST_DIR}/OpenCLProgramBuildTest.cpp ) +list(APPEND MNN_CPP_TOOLS OpenCLProgramBuildTest.out) + add_executable(SequenceModuleTest.out ${CMAKE_CURRENT_LIST_DIR}/SequenceModuleTest.cpp) list(APPEND MNN_CPP_TOOLS SequenceModuleTest.out) diff --git a/tools/cpp/OpenCLProgramBuildTest.cpp b/tools/cpp/OpenCLProgramBuildTest.cpp new file mode 100644 index 00000000..2165e3e8 --- /dev/null +++ b/tools/cpp/OpenCLProgramBuildTest.cpp @@ -0,0 +1,244 @@ +// +// OpenCLProgramBuildTest.cpp +// MNN +// +// Created by MNN on 2025/5/15. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#include +#include +#include +#include "CL/cl.h" +#ifdef _WIN32 +#include +#include +#else +#include +#endif + +using clGetPlatformIDsFunc = cl_int (CL_API_CALL *)(cl_uint, cl_platform_id *, cl_uint *); +using clBuildProgramFunc = cl_int (CL_API_CALL *)(cl_program, cl_uint, const cl_device_id *, const char *, void (CL_CALLBACK *pfn_notify)(cl_program, void *), void *); +using clCreateProgramWithSourceFunc = cl_program (CL_API_CALL *)(cl_context, cl_uint, const char **, const size_t *, cl_int *); +using clGetProgramBuildInfoFunc = cl_int (CL_API_CALL *)(cl_program, cl_device_id, cl_program_build_info, size_t, void *, size_t *); +using clCreateContextFunc = cl_context (CL_API_CALL *)(const cl_context_properties *, cl_uint, const cl_device_id *, + void(CL_CALLBACK *)( // NOLINT(readability/casting) + const char *, const void *, size_t, void *), + void *, cl_int *); +using clGetDeviceIDsFunc = cl_int (CL_API_CALL *)(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *); +using clGetDeviceInfoFunc = cl_int (CL_API_CALL *)(cl_device_id, cl_device_info, size_t, void *, size_t *); +using clReleaseProgramFunc = cl_int (CL_API_CALL *)(cl_program program); +using clReleaseContextFunc = cl_int (CL_API_CALL *)(cl_context); +using clReleaseDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id); + + +class OpenCLProgramTest { +public: + OpenCLProgramTest(){ + static const std::vector gOpencl_library_paths = { + + #if defined(__APPLE__) || defined(__MACOSX) + "libOpenCL.so", "/System/Library/Frameworks/OpenCL.framework/OpenCL" + #elif defined(__OHOS__) + "/vendor/lib64/chipsetsdk/libhvgr_v200.so", + "/vendor/lib64/chipsetsdk/libGLES_mali.so", + "/system/lib64/libGLES_mali.so", + "libGLES_mali.so", + "/vendor/lib64/chipsetsdk/libEGI_imp1.so", + #elif defined(__ANDROID__) + "libOpenCL.so", + "libGLES_mali.so", + "libmali.so", + "libOpenCL-pixel.so", + #if defined(__aarch64__) + // Qualcomm Adreno + "/system/vendor/lib64/libOpenCL.so", + "/system/lib64/libOpenCL.so", + // Mali + "/system/vendor/lib64/egl/libGLES_mali.so", + "/system/lib64/egl/libGLES_mali.so", + #else + // Qualcomm Adreno + "/system/vendor/lib/libOpenCL.so", "/system/lib/libOpenCL.so", + // Mali + "/system/vendor/lib/egl/libGLES_mali.so", "/system/lib/egl/libGLES_mali.so", + // other + "/system/vendor/lib/libPVROCL.so", "/data/data/org.pocl.libs/files/lib/libpocl.so" + #endif + #elif defined(__linux__) + "/usr/lib/libOpenCL.so", + "/usr/local/lib/libOpenCL.so", + "/usr/local/lib/libpocl.so", + "/usr/lib64/libOpenCL.so", + "/usr/lib32/libOpenCL.so", + "libOpenCL.so" + #elif defined(_WIN64) + "C:/Windows/System32/OpenCL.dll", + "C:/Windows/SysWOW64/OpenCL.dll" + #elif defined(_WIN32) + "C:/Windows/SysWOW64/OpenCL.dll", + "C:/Windows/System32/OpenCL.dll" + #endif + }; + + for (const auto &opencl_lib : gOpencl_library_paths) { + if (LoadLibraryFromPath(opencl_lib)) { + mIsSupportAvailable = true; + } + } + if(mIsSupportAvailable){ + cl_int err; + err = clGetPlatformIDs(1, &platform, NULL); + if (err != CL_SUCCESS) { + printf("Failed to get platform ID err = %d\n", err); + return ; + } + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, NULL); + if (err != CL_SUCCESS) { + printf("Failed to get device ID err = %d\n", err); + return; + } + + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (!context || err != CL_SUCCESS) { + printf("Failed to create context err = %d\n", err); + return; + } + } + } + bool LoadLibraryFromPath(const std::string &library_path){ + #if defined(_WIN32) + handle_ = LoadLibraryA(library_path.c_str()); + if (handle_ == nullptr) { + return false; + } + #define MNN_LOAD_FUNCTION_PTR(func_name) func_name = reinterpret_cast(GetProcAddress(static_cast(handle_), #func_name)); + #else + handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle_ == nullptr) { + return false; + } + #define MNN_LOAD_FUNCTION_PTR(func_name) func_name = reinterpret_cast(dlsym(handle_, #func_name)); + #endif + MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs); + MNN_LOAD_FUNCTION_PTR(clBuildProgram); + MNN_LOAD_FUNCTION_PTR(clCreateProgramWithSource); + MNN_LOAD_FUNCTION_PTR(clGetProgramBuildInfo); + MNN_LOAD_FUNCTION_PTR(clCreateContext); + MNN_LOAD_FUNCTION_PTR(clGetDeviceIDs); + MNN_LOAD_FUNCTION_PTR(clGetDeviceInfo); + MNN_LOAD_FUNCTION_PTR(clReleaseProgram); + MNN_LOAD_FUNCTION_PTR(clReleaseContext); + MNN_LOAD_FUNCTION_PTR(clReleaseDevice); + + return true; + } + bool TestProgram(const std::vector options){ + cl_int err; + FILE* file = fopen("kernel.cl", "r"); + if (!file) { + printf("Failed to open kernel file: kernel.cl\n"); + return false; + } + + fseek(file, 0, SEEK_END); + size_t fileSize = ftell(file); + rewind(file); + + char* source = (char*)malloc(fileSize + 1); + if (!source) { + fclose(file); + printf("Memory allocation failed for kernel source\n"); + return false; + } + + fread(source, sizeof(char), fileSize, file); + source[fileSize] = '\0'; + fclose(file); + + // test program + const char *code = source; + cl_program program = clCreateProgramWithSource(context, 1, &code, &fileSize, &err); + if (!program || err != CL_SUCCESS) { + printf("Failed to create program from source\n"); + return false; + } + for(int i = 0; i < options.size(); ++i){ + err = clBuildProgram(program, 1, &device, options[i].c_str(), NULL, NULL); + if (err != CL_SUCCESS) { + size_t logSize; + clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize); + char *buildLog = (char*)malloc(logSize); + clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, logSize, buildLog, NULL); + printf("Program build log: "); + for (int i = 0; i < logSize; i++) { + printf("%c", buildLog[i]); + } + clReleaseProgram(program); + free(buildLog); + return false; + } + } + + clReleaseProgram(program); + free(source); + return true; + } + bool mIsSupportAvailable = false; + ~OpenCLProgramTest(){ + if(mIsSupportAvailable){ + clReleaseDevice(device); + clReleaseContext(context); + } + if (handle_ != nullptr) { +#if defined(_WIN32) + FreeLibrary(static_cast(handle_)); +#else + dlclose(handle_); +#endif + } + } +private: + void *handle_ = nullptr; + clGetPlatformIDsFunc clGetPlatformIDs; + clBuildProgramFunc clBuildProgram; + clCreateProgramWithSourceFunc clCreateProgramWithSource; + clGetProgramBuildInfoFunc clGetProgramBuildInfo; + clCreateContextFunc clCreateContext; + clGetDeviceIDsFunc clGetDeviceIDs; + clGetDeviceInfoFunc clGetDeviceInfo; + clReleaseProgramFunc clReleaseProgram; + clReleaseContextFunc clReleaseContext; + clReleaseDeviceFunc clReleaseDevice; + cl_platform_id platform; + cl_device_id device; + cl_context context; +}; + +int main(int argc, char *argv[]) { + std::string filename; + if(argc > 1){ + filename = argv[1]; + } + std::vector options; + std::fstream file("option.txt"); + if(file.is_open()){ + std::string line; + while (getline(file, line)) { // 按行读取文件内容并输出 + options.push_back(line); + } + file.close(); + } + printf("test filename is %s\n", filename.c_str()); + OpenCLProgramTest BuildTest; + if(BuildTest.mIsSupportAvailable){ + if(BuildTest.TestProgram(options)){ + return 0; + } + }else{ + printf("OpenCL init fail\n"); + return -1; + } + return 0; +} + diff --git a/tools/script/opencl_kernel_check.py b/tools/script/opencl_kernel_check.py new file mode 100644 index 00000000..4b413d96 --- /dev/null +++ b/tools/script/opencl_kernel_check.py @@ -0,0 +1,205 @@ +import sys +import os +import re +import itertools + +def run_cmd(args): + from subprocess import Popen, PIPE, STDOUT + stdout, _ = Popen(args, stdout=PIPE, stderr=STDOUT).communicate() + return stdout.decode('utf-8') + +def extract_macros(file_content): + """提取宏定义""" + macros = {} + macros_num = {} + ifdef_pattern = re.compile(r'#(ifdef)\s+(\w+)') + ifndef_pattern = re.compile(r'#(ifndef)\s+(\w+)') + if_pattern = re.compile(r'#(if)\s+(\w+)') + elif_pattern = re.compile(r'#(elif)\s+(\w+)') + defined_pattern = re.compile(r'(defined)\s+(\w+)') + define_pattern = re.compile(r'#(define)\s+(\w+)') + for match in ifdef_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + if "LOCAL_SIZE" in macro_name: + macros_num[macro_name] = {1, 2, 3, 4, 16} + else: + macros[macro_name] = None + + for match in ifndef_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + if "LOCAL_SIZE" in macro_name: + macros_num[macro_name] = {1, 2, 3, 4, 16} + else: + macros[macro_name] = None + + for match in if_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + if macro_name != "defined": + macros_num[macro_name] = {1, 2, 3, 4, 8} + + for match in elif_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + if macro_name != "defined": + macros_num[macro_name] = {1, 2, 3, 4, 8} + + for match in defined_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + macros[macro_name] = None + + for match in define_pattern.finditer(file_content): + macro_type, macro_name = match.groups() + if macro_name in macros: + del macros[macro_name] + if macro_name in macros_num: + del macros_num[macro_name] + + if "MNN_SUPPORT_FP16" in macros: + del macros["MNN_SUPPORT_FP16"] + + #for macro_name, macro_value in macros.items(): + # Replace macro value + #print(f"macro_name {macro_name} macro_value {macro_value}") + return [macros_num, macros] + +def compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android): + """ + Tries to compile the kernel given various macro values + """ + macros_num = macros_all[0] + macros = macros_all[1] + float_option = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT=convert_float -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT3=convert_float3 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT=convert_float -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT3=convert_float3 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16" + float_option += " -DINPUT_TYPE_I=float -DINPUT_TYPE_I4=float4 -DINPUT_TYPE=float -DINPUT_TYPE4=float4 -DINPUT_TYPE16=float16 -DRI_DATA=read_imagef -DOUTPUT_TYPE_I=float -DOUTPUT_TYPE_I4=float4 -DCONVERT_OUTPUT_I4=convert_float4 -DOUTPUT_TYPE=float -DOUTPUT_TYPE4=float4 -DOUTPUT_TYPE16=float16 -DCONVERT_OUTPUT4=convert_float4 -DCONVERT_OUTPUT16=convert_float16 -DWI_DATA=write_imagef" + if filename in extra_macro: + float_option += extra_macro[filename] + keys = list(macros.keys()) + + # 使用 itertools.product 生成所有可能的 0 和 1 的组合 + combinations = list(itertools.product([0, 1], repeat=len(keys))) + + options_normal = [] + # 获取普通的宏定义 + for combination in combinations: + option_normal = float_option + macros_out = dict(zip(keys, combination)) + for macro_name, macro_value in macros_out.items(): + if macro_value == 1: + option_normal += f" -D{macro_name}={macro_value} " + options_normal.append(option_normal) + + options_num_normal = [] + # 获取有多种取值的宏 + if len(macros_num) > 0 : + option_num = "" + for i in {1, 2, 3, 4, 8} : + for macro_name in macros_num: + option_num = f" -D{macro_name}={i} " + for option_normal in options_normal: + options_num_normal.append(option_normal + option_num) + else: + options_num_normal = options_normal + + options = [] + # 获取OPERATOR的宏, 只需要验证第一个OPERATOR宏与其他宏的各种组合,其他的可以只验证一种组合 + if len(operator_macro) > 0 : + has_combine = False + for op in operator_macro: + option_operator = f" -DOPERATOR={op} " + if has_combine is True: + options.append(options_num_normal[0] + option_operator) + else: + for option_num_normal in options_num_normal: + options.append(option_num_normal + option_operator) + has_combine = True + else: + options = options_num_normal + + + with open('option.txt', 'w') as outfile: + for option in options: + outfile.write(option + '\n') + + if test_for_android == 1: + run_cmd(['adb', 'push', 'kernel.cl', '/data/local/tmp/MNN']) + run_cmd(['adb', 'push', 'option.txt', '/data/local/tmp/MNN']) + run_cmd(['adb', 'push', 'OpenCLProgramBuildTest.out', '/data/local/tmp/MNN']) + res = run_cmd(['adb', 'shell', 'cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH && ./OpenCLProgramBuildTest.out %s'%(filename)]) + print(res) + else: + if sys.platform.startswith('win'): + res = run_cmd(['OpenCLProgramBuildTest.exe', f'{filename}']) + print(res) + else: + res = run_cmd(['./OpenCLProgramBuildTest.out', f'{filename}']) + print(res) + +def main(): + print("opencl_kernel_check.py path without_subgroup test_for_android") + path = '.' + without_subgroup = 1 + test_for_android = 0 + if len(sys.argv) > 1: + path = sys.argv[1] + + if len(sys.argv) > 2: + without_subgroup = int(sys.argv[2]) + + if len(sys.argv) > 3: + test_for_android = int(sys.argv[3]) + + binaryvec_operator = {"in0+in1", "in0*in1", "in0-in1", "in0>in1?in0:in1", "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))", + "in0>in1?in0:in1", "convert_float4(-isgreater(in0,in1))", "convert_float4(-isless(in0,in1))", "convert_float4(-islessequal(in0,in1))", "convert_float4(-isgreaterequal(in0,in1))", "convert_float4(-isequal(in0,in1))", + "floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", + "pow(in0,in1)", "(in0-in1)*(in0-in1)", "(in1==(float)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float)PI)))", "convert_float4(-isnotequal(in0,in1))", + "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1"} + + binary_operator = {"in0*in1", "in0+in1", "in0-in1", "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))", "in0>in1?in1:in0", "in0>in1?in0:in1", "(float)(isgreater(in0,in1))", + "(float)(isless(in0,in1))", "(float)(islessequal(in0,in1))", "(float)(isgreaterequal(in0,in1))", "(float)(isequal(in0,in1))", "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))", + "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", "pow(in0,in1)", "(in0-in1)*(in0-in1)", + "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))", "(float)(isnotequal(in0,in1))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1"} + + unary_operator = {"fabs(convert_float4(in))", "in*in", "rsqrt(convert_float4(in)>(float4)(0.000001)?convert_float4(in):(float4)(0.000001))", "-(in)", "exp(convert_float4(in))", "cos(convert_float4(in))", "sin(convert_float4(in))", + "tan(convert_float4(in))", "atan(convert_float4(in))", "sqrt(convert_float4(in))", "ceil(convert_float4(in))", "native_recip(convert_float4(in))", "log1p(convert_float4(in))", "native_log(convert_float4(in)>(float4)(0.0000001)?convert_float4(in):(float4)(0.0000001))", + "floor(convert_float4(in))", "in>(float4)((float)0)?(in+native_log(exp(convert_float4(-(in)))+(float4)(1.0))):(native_log(exp(convert_float4(in))+(float4)(1.0)))", "acosh(convert_float4(in))", "sinh(convert_float4(in))", "asinh(convert_float4(in))", + "atanh(convert_float4(in))", "sign(convert_float4(in))", "round(convert_float4(in))", "cosh(convert_float4(in))", "erf(convert_float4(in))", "erfc(convert_float4(in))", "expm1(convert_float4(in))", "native_recip((float4)1+native_exp(convert_float4(-in)))", + "(convert_float4(in)*native_recip((float4)1+native_exp(convert_float4(-in))))", "tanh(convert_float4(in))", "convert_float4(in)>(float4)(-3.0f)?(convert_float4(in)<(float4)(3.0f)?((convert_float4(in)*(convert_float4(in)+(float4)3.0f))/(float4)6.0f):convert_float4(in)):(float4)(0.0f)", + "gelu(convert_float4(in))", "(erf(convert_float4(in)*(float4)0.7071067932881648)+(float4)1.0)*convert_float4(in)*(float4)0.5", "native_recip((float4)(1.0)+native_exp(convert_float4(-(in))))", + "tanh(convert_float4(in))"} + + extra_macro = {} + extra_macro["binary_subgroup_buf.cl"] = " -DINTEL_DATA=uint -DAS_INPUT_DATA=as_float -DAS_INPUT_DATA4=as_float4 -DAS_OUTPUT_DATA4=as_uint4 -DINTEL_SUB_GROUP_READ=intel_sub_group_block_read -DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4 -DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4" + extra_macro["conv_2d_c1_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DINPUT_BLOCK_SIZE=16 -DINPUT_CHANNEL=16 -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1" + extra_macro["conv_2d_c16_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DINPUT_BLOCK_SIZE=16 -DINPUT_CHANNEL=16 -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1" + extra_macro["depthwise_conv2d_subgroup_buf.cl"] = " -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1" + extra_macro["matmul_local_buf.cl"] = " -DOPWM=64 -DOPWN=128 -DCPWK=8 -DOPTM=4 -DOPTN=8" + extra_macro["pooling_subgroup_buf.cl"] = " -DINPUT_LINE_SIZE=16 -DSTRIDE_Y=2 -DSTRIDE_X=2 -DKERNEL_Y=4 -DKERNEL_X=4" + extra_macro["reduction_buf.cl"] = " -DOPERATE(a,b)=(a+b) -DVALUE=0" + extra_macro["reduction.cl"] = " -DOPERATE(a,b)=(a+b) -DVALUE=0" + extra_macro["unary_subgroup_buf.cl"] = " -DINTEL_DATA=uint -DAS_INPUT_DATA=as_float -DAS_INPUT_DATA4=as_float4 -DAS_OUTPUT_DATA4=as_uint4 -DINTEL_SUB_GROUP_READ=intel_sub_group_block_read -DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4 -DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4" + + # 遍历当前目录的所有.cl文件 + for filename in os.listdir(path): + if filename.endswith('.cl'): + source_file = os.path.join(path, filename) + with open(source_file, 'r') as file: + file_content = file.read() + + with open('kernel.cl', 'w') as outfile: + outfile.write(file_content) + + # 提取宏定义 + macros_all = extract_macros(file_content) + # Compile with different macro values + operator_macro = {} + if filename == "binary_buf.cl" or filename == "binary.cl" or filename == "loop.cl" or filename == "binary_subgroup_buf.cl": + operator_macro = binaryvec_operator + elif filename == "loop_buf.cl": + operator_macro = binary_operator + elif filename == "unary_buf.cl" or filename == "unary.cl" or filename == "unary_subgroup_buf.cl": + operator_macro = unary_operator + + if "subgroup" in filename and without_subgroup == 1: + continue + compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android) + +if __name__ == "__main__": + main() diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt index f271c961..1bb50996 100644 --- a/transformers/llm/engine/CMakeLists.txt +++ b/transformers/llm/engine/CMakeLists.txt @@ -1,6 +1,7 @@ option(LLM_SUPPORT_VISION "Llm model support vision input." OFF) option(LLM_SUPPORT_AUDIO "Llm model support audio input." OFF) option(BUILD_MLS "Build PC Commandline." OFF) +option(LLM_USE_MINJA "Use minja to apply template" ON) set(LLM_DEPS ${MNN_DEPS}) if (LLM_SUPPORT_VISION AND MNN_BUILD_OPENCV) @@ -22,7 +23,7 @@ endif() include_directories(${CMAKE_CURRENT_LIST_DIR}/include/) # source files -FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp ${CMAKE_CURRENT_LIST_DIR}/src/speculative_decoding/*.cpp) +FILE(GLOB_RECURSE SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*) if (MNN_SEP_BUILD) if (MNN_BUILD_SHARED_LIBS) @@ -37,12 +38,17 @@ if (MNN_SEP_BUILD) else() add_library(llm OBJECT ${SRCS}) endif() +if (LLM_USE_MINJA) + target_compile_options(llm PRIVATE -DLLM_USE_MINJA) + add_executable(apply_template ${CMAKE_CURRENT_LIST_DIR}/demo/apply_template.cpp) + target_link_libraries(apply_template ${LLM_DEPS}) +endif() if (LLM_SUPPORT_VISION AND MNN_BUILD_OPENCV) target_compile_definitions(llm PRIVATE LLM_SUPPORT_VISION) endif() if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO) - target_compile_definitions(llm PRIVATE LLM_SUPPORT_AUDIO) + add_definitions(-DLLM_SUPPORT_AUDIO) endif() add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp) @@ -54,6 +60,7 @@ target_link_libraries(embedding_demo ${LLM_DEPS}) target_link_libraries(rollback_demo ${LLM_DEPS}) target_link_libraries(llm_bench ${LLM_DEPS}) + if (BUILD_MLS) set(CMAKE_OSX_DEPLOYMENT_TARGET "13.0" CACHE STRING "Minimum macOS version" FORCE) diff --git a/transformers/llm/engine/app/mls.cpp b/transformers/llm/engine/app/mls.cpp index 164d6e6e..4e717758 100644 --- a/transformers/llm/engine/app/mls.cpp +++ b/transformers/llm/engine/app/mls.cpp @@ -166,41 +166,28 @@ static int serve(int argc, const char *argv[]) { bool invalid_param{false}; std::string config_path{}; std::string arg{}; - std::string model_name{}; if (argc < 3) { print_usage(); return 1; } - + arg = argv[2]; + if (arg.find('-') != 0) { + config_path = (fs::path(mls::FileUtils::GetBaseCacheDir()) / arg / "config.json").string(); + } for (int i = 2; i < argc; i++) { arg = argv[i]; - - if (arg.find('-') != 0) { - model_name = arg; - config_path = (fs::path(mls::FileUtils::GetBaseCacheDir()) / arg / "config.json").string(); - continue; - } if (arg == "-c") { if (++i >= argc) { invalid_param = true; break; } config_path = mls::FileUtils::ExpandTilde(argv[i]); - model_name = mls::FileUtils::GetFileName(fs::path(config_path).parent_path()); - continue; } } - - if (invalid_param) { - std::cerr << "Error: Missing value after -c option" << std::endl; - print_usage(); - return 1; - } - mls::MlsServer server; bool is_r1 = IsR1(config_path); auto llm = create_and_prepare_llm(config_path.c_str(), !is_r1); - server.Start(model_name, llm.get(), is_r1); + server.Start(llm.get(), is_r1); return 0; } diff --git a/transformers/llm/engine/app/mls_server.cpp b/transformers/llm/engine/app/mls_server.cpp index 2ff2c1dc..44e6cd9f 100644 --- a/transformers/llm/engine/app/mls_server.cpp +++ b/transformers/llm/engine/app/mls_server.cpp @@ -21,96 +21,19 @@ std::string GetCurrentTimeAsString() { return std::to_string(seconds); } -// bool FromJson(const json& j, PromptItem& item) { -// if (!j.is_object()) { -// return false; -// } -// if (!j.contains("role") || !j["role"].is_string()) { -// return false; -// } -// if (!j.contains("content") || !j["content"].is_string()) { -// return false; -// } - -// item.first = j["role"].get(); // Role -// item.second = j["content"].get(); // Content -// return true; -// } - -std::string base64_decode(const std::string &ascdata) { - static const char b64_table[65] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - - static const char reverse_table[128] = { - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63, - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64, - 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64 - }; - - std::string retval; - const std::string::const_iterator last = ascdata.end(); - int bits_collected = 0; - unsigned int accumulator = 0; - - for (std::string::const_iterator i = ascdata.begin(); i != last; ++i) { - const int c = *i; - if (::std::isspace(c) || c == '=') { - // Skip whitespace and padding. Be liberal in what you accept. - continue; - } - if ((c > 127) || (c < 0) || (reverse_table[c] > 63)) { - MNN_ERROR("Base64 decode auth code failed.\n"); - return ""; - } - accumulator = (accumulator << 6) | reverse_table[c]; - bits_collected += 6; - if (bits_collected >= 8) { - bits_collected -= 8; - retval += static_cast((accumulator >> bits_collected) & 0xffu); - } - } - return retval; -} - -static std::string SaveImageFromDataUrl(const std::string& data_url) { - auto comma = data_url.find(","); - std::string b64 = (comma != std::string::npos ? data_url.substr(comma + 1) : data_url); - auto bytes = base64_decode(b64); - std::string filename = "image_" + GetCurrentTimeAsString() + ".jpg"; - std::ofstream ofs(filename, std::ios::binary); - ofs.write(reinterpret_cast(bytes.data()), bytes.size()); - ofs.close(); - return filename; -} - bool FromJson(const json& j, PromptItem& item) { - if (!j.is_object() || !j.contains("role")) return false; - item.first = j["role"].get(); - if (!j.contains("content")) return false; - - // Handle text or array content - if (j["content"].is_string()) { - item.second = j["content"].get(); - } else if (j["content"].is_array()) { - std::string combined; - for (const auto& elem : j["content"]) { - if (elem.is_object() && elem.value("type", "") == "image_url" - && elem.contains("image_url") - && elem["image_url"].contains("url")) { - std::string data_url = elem["image_url"]["url"].get(); - std::string path = SaveImageFromDataUrl(data_url); - combined += "" + path + ""; - } - // other types can be handled here - } - item.second = combined; - } else { - return false; + if (!j.is_object()) { + return false; } + if (!j.contains("role") || !j["role"].is_string()) { + return false; + } + if (!j.contains("content") || !j["content"].is_string()) { + return false; + } + + item.first = j["role"].get(); // Role + item.second = j["content"].get(); // Content return true; } @@ -229,12 +152,7 @@ void AllowCors(httplib::Response& res) { res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization"); } -std::vector MlsServer::GetLocalModels() { - return {current_model_name_}; -} - -void MlsServer::Start(const std::string& model_name, MNN::Transformer::Llm* llm, bool is_r1) { - this->current_model_name_ = model_name; +void MlsServer::Start(MNN::Transformer::Llm* llm, bool is_r1) { this->is_r1_ = is_r1; // Create a server instance httplib::Server server; @@ -244,146 +162,100 @@ void MlsServer::Start(const std::string& model_name, MNN::Transformer::Llm* llm, AllowCors(res); res.set_content(html_content, "text/html"); }); - - // Add OpenAI compatible /v1 routes for all APIs - auto handleChatCompletions = [&](const httplib::Request &req, httplib::Response &res, bool is_v1) { - if (!json::accept(req.body)) { - json err; - err["error"] = "Invalid JSON in request body."; - res.status = 400; - res.set_content(err.dump(), "application/json"); - return; - } - json request_json = json::parse(req.body, nullptr, false); - json messages = request_json["messages"]; - std::cout<<"received messages:"<(time(nullptr))}, - {"model", model}, - { - "choices", json::array({ - { - {"index", 0}, - { - "message", { - {"role", "assistant"}, - {"content", answer} - } - }, - {"finish_reason", "stop"} - } - }) - }, - { - "usage", { - {"prompt_tokens", 10}, - {"completion_tokens", 7}, - {"total_tokens", 17} - } - } - }; - res.set_content(response_json.dump(), "application/json"); - }); - return; - } - res.set_header("Content-Type", "text/event-stream"); - res.set_header("Cache-Control", "no-cache"); - res.set_header("Connection", "keep-alive"); - res.set_chunked_content_provider( - "text/event-stream", - [llm, messages, model, this](size_t /*offset*/, httplib::DataSink &sink) { - auto sse_callback = [&, this](const std::string &partial_text, bool end) { - std::string finish_reason = end ? "stop" : ""; - json sse_json = { - {"id", "chatcmpl-" + GetCurrentTimeAsString()}, - {"object", "chat.completion.chunk"}, - {"created", static_cast(std::time(nullptr))}, - {"model", model}, - {"choices", json::array({ - { - {"delta", {{"content", partial_text}}}, - {"index", 0}, - {"finish_reason", finish_reason} - } - })} - }; - std::string chunk_str = "data: " + sse_json.dump() + "\n\n"; - sink.os.write(chunk_str.c_str(), chunk_str.size()); - sink.os.flush(); - }; - AnswerStreaming(llm, messages, sse_callback); - std::string done_str = "data: [DONE]\n\n"; - sink.os.write(done_str.c_str(), done_str.size()); - sink.os.flush(); - sink.done(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - return false; - } - ); - }; - - // Route registrations - server.Get("/v1/", [this](const httplib::Request& req, httplib::Response& res) { - AllowCors(res); - res.set_content(html_content, "text/html"); - }); - server.Post("/reset", [&](const httplib::Request &req, httplib::Response &res) { - printf("POST /reset\n"); - llm->reset(); - res.set_content("{\"status\": \"ok\"}", "application/json"); + printf("POST /reset\n"); + AllowCors(res); + llm->reset(); + res.set_content("{\"status\": \"ok\"}", "application/json"); }); - server.Post("/v1/reset", [&](const httplib::Request &req, httplib::Response &res) { - printf("POST /v1/reset\n"); - llm->reset(); - res.set_content("{\"status\": \"ok\"}", "application/json"); - }); - server.Options("/chat/completions", [](const httplib::Request& /*req*/, httplib::Response& res) { AllowCors(res); res.status = 200; }); - server.Options("/v1/chat/completions", [](const httplib::Request& /*req*/, httplib::Response& res) { - AllowCors(res); - res.status = 200; - }); - - server.Get("/models/list", [this](const httplib::Request& req, httplib::Response& res) { - AllowCors(res); - std::vector model_names = GetLocalModels(); - json response = json::array(); - for (const auto& name : model_names) { - response.push_back(name); - } - res.set_content(response.dump(), "application/json"); - }); - server.Get("/v1/models", [this](const httplib::Request& req, httplib::Response& res) { - AllowCors(res); - std::vector model_names = GetLocalModels(); - json response = json::array(); - for (const auto& name : model_names) { - response.push_back(name); - } - res.set_content(response.dump(), "application/json"); - }); - server.Post("/chat/completions", [&](const httplib::Request &req, httplib::Response &res) { std::cout << "POST /chat/completions, handled by thread: " << std::this_thread::get_id() << std::endl; - handleChatCompletions(req, res, false); + AllowCors(res); + if (!json::accept(req.body)) { + json err; + err["error"] = "Invalid JSON in request body."; + res.status = 400; + res.set_content(err.dump(), "application/json"); + return; + } + json request_json = json::parse(req.body, nullptr, false); + json messages = request_json["messages"]; + std::cout<<"received messages:"<(time(nullptr))}, + {"model", model}, + { + "choices", json::array({ + { + {"index", 0}, + { + "message", { + {"role", "assistant"}, + {"content", answer} + } + }, + {"finish_reason", "stop"} + } + }) + }, + { + "usage", { + {"prompt_tokens", 10}, + {"completion_tokens", 7}, + {"total_tokens", 17} + } + } + }; + res.set_content(response_json.dump(), "application/json"); + }); + return; + } + res.set_header("Content-Type", "text/event-stream"); + res.set_header("Cache-Control", "no-cache"); + res.set_header("Connection", "keep-alive"); + res.set_chunked_content_provider( + "text/event-stream", + [llm, messages, model, this](size_t /*offset*/, httplib::DataSink &sink) { + auto sse_callback = [&, this](const std::string &partial_text, bool end) { + std::string finish_reason = end ? "stop" : ""; + json sse_json = { + {"id", "chatcmpl-" + GetCurrentTimeAsString()}, + {"object", "chat.completion.chunk"}, + {"created", static_cast(std::time(nullptr))}, + {"model", model}, + {"choices", json::array({ + { + {"delta", {{"content", partial_text}}}, + {"index", 0}, + {"finish_reason", finish_reason} + } + })} + }; + std::string chunk_str = "data: " + sse_json.dump() + "\n\n"; + sink.os.write(chunk_str.c_str(), chunk_str.size()); + sink.os.flush(); + }; + AnswerStreaming(llm, messages, sse_callback); + std::string done_str = "data: [DONE]\n\n"; + sink.os.write(done_str.c_str(), done_str.size()); + sink.os.flush(); + sink.done(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + return false; + } + ); }); - server.Post("/v1/chat/completions", [&](const httplib::Request &req, httplib::Response &res) { - std::cout << "POST /v1/chat/completions, handled by thread: " - << std::this_thread::get_id() << std::endl; - handleChatCompletions(req, res, true); - }); - // Start the server on port std::cout << "Starting server on http://localhost:9090\n"; if (!server.listen("0.0.0.0", 9090)) { diff --git a/transformers/llm/engine/app/mls_server.hpp b/transformers/llm/engine/app/mls_server.hpp index f57e3e00..ffd8531a 100644 --- a/transformers/llm/engine/app/mls_server.hpp +++ b/transformers/llm/engine/app/mls_server.hpp @@ -258,11 +258,9 @@ class MlsServer { )"""; - void Start(const std::string& modle_name, MNN::Transformer::Llm* llm, bool is_r1); + void Start(MNN::Transformer::Llm* llm, bool is_r1); bool is_r1_{false}; private: - std::string current_model_name_{}; - std::vector GetLocalModels(); void Answer(MNN::Transformer::Llm* llm, const json &messages, std::function on_result); void AnswerStreaming(MNN::Transformer::Llm* llm, const json& messages, diff --git a/transformers/llm/engine/demo/apply_template.cpp b/transformers/llm/engine/demo/apply_template.cpp new file mode 100644 index 00000000..d86790c8 --- /dev/null +++ b/transformers/llm/engine/demo/apply_template.cpp @@ -0,0 +1,193 @@ +#include +#include "../src/minja/chat_template.hpp" +#include +#include +#include +#include +#include +static int test(const char* testjson) { + rapidjson::Document document; + std::ifstream inputFs(testjson); + std::ostringstream osString; + if (inputFs.fail()) { + MNN_ERROR("Open %s error\n", testjson); + return 0; + } + osString << inputFs.rdbuf(); + document.Parse(osString.str().c_str()); + if (document.HasParseError() || (!document.IsArray())) { + MNN_ERROR("Invalid json\n"); + return 0; + } + int pos = 0; + for (auto& iter : document.GetArray()) { + std::string res = iter["res"].GetString(); + std::string chatTemplate = iter["chat_template"].GetString(); + std::string bos; + std::string eos; + if (iter.HasMember("bos")) { + bos = iter["bos"].GetString(); + } + if (iter.HasMember("eos")) { + eos = iter["eos"].GetString(); + } + minja::chat_template tmpl(chatTemplate, bos, eos); + minja::chat_template_inputs inputs; + inputs.messages.CopyFrom(iter["messages"], inputs.messages.GetAllocator()); + if (iter.HasMember("extras")) { + inputs.extra_context.CopyFrom(iter["extras"], inputs.extra_context.GetAllocator()); + } + inputs.add_generation_prompt = true; + auto newres = tmpl.apply(inputs); + if (res != newres) { + MNN_ERROR("Error for %d template\n", pos); + MNN_ERROR("Origin:\n%s\n", res.c_str()); + MNN_ERROR("Compute:\n%s\n", newres.c_str()); + return 0; + } + pos++; + } + MNN_PRINT("Test %d template, All Right\n", pos); + return 0; +} +int main(int argc, const char* argv[]) { + if (argc < 2) { + MNN_ERROR("Usage: ./apply_template token_config.json \n"); + MNN_ERROR("Or \n"); + MNN_ERROR("Usage: ./apply_template test.json 1\n"); + return 0; + } + if (argc >= 3) { + MNN_PRINT("Test %s\n", argv[1]); + test(argv[1]); + return 0; + } + rapidjson::Document resDocument; + { + // Load origin result + std::ifstream inputFs("result.json"); + bool valid = false; + if (!inputFs.fail()) { + std::ostringstream osString; + osString << inputFs.rdbuf(); + resDocument.Parse(osString.str().c_str()); + if (resDocument.HasParseError()) { + MNN_ERROR("Invalid json\n"); + } else { + valid = true; + MNN_PRINT("Has result.json, append it\n"); + } + } + if (!valid) { + resDocument.SetArray(); + MNN_PRINT("Create new result.json\n"); + } + } + for (int i=1; i>& value, std::string& dst) { + if (value.IsString()) { + dst = value.GetString(); + return; + } + if (value.IsObject()) { + if (value.HasMember("content") && value["content"].IsString()) { + dst = value["content"].GetString(); + return; + } + } + }; + if (document.HasMember("bos_token")) { + loadtoken(document["bos_token"], bosToken); + } + if (document.HasMember("eos_token")) { + loadtoken(document["eos_token"], eosToken); + } + std::string templateChat; + if (document.HasMember("chat_template")) { + templateChat = document["chat_template"].GetString(); + } + if (templateChat.empty()) { + MNN_ERROR("Invalid json with no chat_template\n"); + return 0; + } + + minja::chat_template tmpl(templateChat, bosToken, eosToken); + minja::chat_template_inputs inputs; + inputs.extra_context.SetObject(); + inputs.extra_context.GetObject().AddMember("enable_thinking", false, inputs.extra_context.GetAllocator()); + inputs.messages.Parse(R"([ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 8 * 12." + }, + { + "role": "assistant", + "content": "96." + }, + { + "role": "user", + "content": "What is 9 * 8?" + } + ])"); + inputs.add_generation_prompt = true; + auto res = tmpl.apply(inputs); + MNN_PRINT("%s", res.c_str()); + // Write result + rapidjson::Value v; + v.SetObject(); + rapidjson::Value messages; + messages.CopyFrom(inputs.messages, resDocument.GetAllocator()); + rapidjson::Value extras; + extras.CopyFrom(inputs.extra_context, resDocument.GetAllocator()); + v.AddMember("messages", messages, resDocument.GetAllocator()); + v.AddMember("extras", extras, resDocument.GetAllocator()); + { + rapidjson::Value tv; + tv.SetString(templateChat.c_str(), resDocument.GetAllocator()); + v.AddMember("chat_template", tv, resDocument.GetAllocator()); + } + if (!bosToken.empty()) { + rapidjson::Value tv; + tv.SetString(bosToken.c_str(), resDocument.GetAllocator()); + v.AddMember("bos", tv, resDocument.GetAllocator()); + } + if (!eosToken.empty()) { + rapidjson::Value tv; + tv.SetString(eosToken.c_str(), resDocument.GetAllocator()); + v.AddMember("eos", tv, resDocument.GetAllocator()); + } + { + rapidjson::Value tv; + tv.SetString(res.c_str(), resDocument.GetAllocator()); + v.AddMember("res", tv, resDocument.GetAllocator()); + } + resDocument.GetArray().PushBack(v, resDocument.GetAllocator()); + } + rapidjson::StringBuffer buf; + rapidjson::PrettyWriter bufwriter(buf); + resDocument.Accept(bufwriter); + std::ofstream os("result.json"); + os << buf.GetString(); + + return 0; +} diff --git a/transformers/llm/engine/demo/llm_demo.cpp b/transformers/llm/engine/demo/llm_demo.cpp index ea30ca3c..1a172e71 100644 --- a/transformers/llm/engine/demo/llm_demo.cpp +++ b/transformers/llm/engine/demo/llm_demo.cpp @@ -13,6 +13,9 @@ #include #include #include +#ifdef LLM_SUPPORT_AUDIO +#include "audio/audio.hpp" +#endif using namespace MNN::Transformer; static void tuning_prepare(Llm* llm) { @@ -77,6 +80,19 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (max_token_number > 0) { llm->set_config("{\"max_new_tokens\":1}"); } +#ifdef LLM_SUPPORT_AUDIO + std::vector waveform; + llm->setWavformCallback([&](const float* ptr, size_t size, bool last_chunk) { + waveform.reserve(waveform.size() + size); + waveform.insert(waveform.end(), ptr, ptr + size); + if (last_chunk) { + auto waveform_var = MNN::Express::_Const(waveform.data(), {(int)waveform.size()}, MNN::Express::NCHW, halide_type_of()); + MNN::AUDIO::save("output.wav", waveform_var, 24000); + waveform.clear(); + } + return true; + }); +#endif for (int i = 0; i < prompts.size(); i++) { const auto& prompt = prompts[i]; @@ -96,7 +112,6 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ } else { llm->response(prompt); } - llm->reset(); prompt_len += context->prompt_len; decode_len += context->gen_seq_len; vision_time += context->vision_us; @@ -105,6 +120,8 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ decode_time += context->decode_us; sample_time += context->sample_us; } + llm->generateWavform(); + float vision_s = vision_time / 1e6; float audio_s = audio_time / 1e6; float prefill_s = prefill_time / 1e6; @@ -170,12 +187,20 @@ static int eval(Llm* llm, std::string prompt_file, int max_token_number) { std::ifstream prompt_fs(prompt_file); std::vector prompts; std::string prompt; +//#define LLM_DEMO_ONELINE +#ifdef LLM_DEMO_ONELINE + std::ostringstream tempOs; + tempOs << prompt_fs.rdbuf(); + prompt = tempOs.str(); + prompts = {prompt}; +#else while (std::getline(prompt_fs, prompt)) { if (prompt.back() == '\r') { prompt.pop_back(); } prompts.push_back(prompt); } +#endif prompt_fs.close(); if (prompts.empty()) { return 1; @@ -240,6 +265,16 @@ int main(int argc, const char* argv[]) { std::istringstream os(argv[3]); os >> max_token_number; } + if (argc >= 5) { + MNN_PRINT("Set not thinking, only valid for Qwen3\n"); + llm->set_config(R"({ + "jinja": { + "context": { + "enable_thinking":false + } + } + })"); + } std::string prompt_file = argv[2]; return eval(llm.get(), prompt_file, max_token_number); } diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index a5c960c1..36c68214 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -81,7 +81,7 @@ public: virtual Express::VARP embedding(const std::vector& input_ids); Express::VARP forward(const std::vector& input_ids, bool is_prefill = true); Express::VARP forward(MNN::Express::VARP input_embeds); - virtual Express::VARP forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos); + virtual std::vector forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos); virtual int sample(Express::VARP logits, int offset = 0, int size = 0); void reset(); void tuning(TuneType type, std::vector candidates); @@ -130,8 +130,9 @@ protected: std::vector> mModules, mPrefillModules, mDecodeModules, mCurrentModules; const Express::Module* mBaseModule = nullptr; Express::VARP inputsEmbeds, attentionMask, positionIds; - std::vector mInputsEmbedsVarVec, mAttentionMaskVarVec, mPositionIdsVarVec; + std::vector mAttentionMaskVarVec, mPositionIdsVarVec; Express::VARP logitsAllIdx, logitsLastIdx; + int mSeqLenIndex = 0; private: // decoding phase will use speculative decoding void speculativeGenerate(int max_token); diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 947c7341..f4fa05f9 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -10,7 +10,6 @@ #include #include #include - #include #include #include "cpp/ExprDebug.hpp" @@ -261,7 +260,6 @@ void Llm::load() { logitsLastIdx = _var({-1}, {1}); logitsAllIdx = _var({0}, {1}); // index match with seq_len - mInputsEmbedsVarVec.resize(decode_type_num); mAttentionMaskVarVec.resize(decode_type_num); mPositionIdsVarVec.resize(decode_type_num); for(int i = 0; i < decode_type_num; i++) { @@ -281,7 +279,6 @@ void Llm::load() { } mPositionIdsVarVec[i] = _Input({index}, NCHW, halide_type_of()); - mInputsEmbedsVarVec[i] = _Input({index, 1, mConfig->hidden_size()}, NCHW); } } @@ -345,6 +342,8 @@ void Llm::tuning(TuneType type, std::vector candidates) { } mCurrentModules = mDecodeModules; int decode_seq = 1; + // Set to decode mode + mContext->gen_seq_len = 1; if(mLookAhead) { // start autoregressive decoding std::vector input_ids = {0}; @@ -402,9 +401,10 @@ void Llm::setKVCacheInfo(size_t add, size_t remove, int* reserve, int n_reserve) mMeta->add = add; } -Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) { +std::vector Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) { VARP logits; Express::VARP logitsIndex; + bool inDecode = mContext->gen_seq_len > 0; // TODO : to be improved if (mConfig->all_logits() || (mLookAhead && mCurrentModules.size() > 1)) { logitsIndex = logitsAllIdx; @@ -420,7 +420,7 @@ Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Exp outputs = mCurrentModules.back()->onForward({hiddenState, mask, inputPos, logitsIndex}); } if (outputs.empty()) { - return nullptr; + return outputs; } logits = outputs[0]; @@ -477,9 +477,8 @@ Express::VARP Llm::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Exp } } #endif - mMeta->sync(); - return logits; + return outputs; } VARP Llm::forward(const std::vector& input_ids, bool is_prefill) { @@ -488,11 +487,15 @@ VARP Llm::forward(const std::vector& input_ids, bool is_prefill) { } VARP Llm::forward(MNN::Express::VARP input_embeds) { - int seq_len = input_embeds->getInfo()->dim[0]; + int seq_len = input_embeds->getInfo()->dim[mSeqLenIndex]; mMeta->add = seq_len; auto attention_mask = gen_attention_mask(seq_len); auto position_ids = gen_position_ids(seq_len); - auto logits = forwardRaw(input_embeds, attention_mask, position_ids); + auto out = forwardRaw(input_embeds, attention_mask, position_ids); + if (out.empty()) { + return nullptr; + } + auto logits = out[0]; mContext->all_seq_len += seq_len; mContext->gen_seq_len++; return logits; @@ -511,6 +514,7 @@ void Llm::reset() { mContext->output_tokens.clear(); mContext->history_tokens.clear(); mContext->all_seq_len = 0; + mMeta->remove = mMeta->previous; } void Llm::generate_init(std::ostream* os, const char* end_with) { @@ -728,16 +732,7 @@ VARP Llm::embedding(const std::vector& input_ids) { AUTOTIME; int hidden_size = mConfig->hidden_size(); int seq_len = static_cast(input_ids.size()); - if (mInputsEmbedsVarVec.size() > 0) { - if(seq_len == 1) { - mDiskEmbedding->embedding(input_ids, mInputsEmbedsVarVec[0]->writeMap()); - return mInputsEmbedsVarVec[0]; - } - if(mInputsEmbedsVarVec.size() > 1 && seq_len == mDraftLength) { - mDiskEmbedding->embedding(input_ids, mInputsEmbedsVarVec[1]->writeMap()); - return mInputsEmbedsVarVec[1]; - } - } + VARP res = _Input({seq_len, 1, hidden_size}, NCHW); // disk embedding to save memory mDiskEmbedding->embedding(input_ids, res->writeMap()); @@ -755,27 +750,24 @@ std::string Llm::tokenizer_decode(int id) { } VARP Llm::gen_attention_mask(int seq_len) { - bool useSquareMask = false; int kv_seq_len = mContext->all_seq_len + seq_len; if (mConfig->attention_mask() == "float") { - // currently only metal supoort square mask - useSquareMask = (mConfig->backend_type() == "metal"); - if(useSquareMask) { - kv_seq_len = seq_len; - } - if(seq_len == 1) { - return mAttentionMaskVarVec[0]; - } - if (useSquareMask && mAttentionMaskVarVec.size() > 1 && seq_len == mDraftLength) { - return mAttentionMaskVarVec[1]; + // Use square mask + kv_seq_len = seq_len; + if (mAttentionMaskVarVec.size() > 0) { + if(seq_len == 1) { + return mAttentionMaskVarVec[0]; + } + if (mAttentionMaskVarVec.size() > 1 && seq_len == mDraftLength) { + return mAttentionMaskVarVec[1]; + } } attentionMask = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of()); auto ptr = attentionMask->writeMap(); for (int i = 0; i < seq_len; i++) { for (int j = 0; j < kv_seq_len; j++) { - int row = useSquareMask ? i : i + mContext->all_seq_len; - ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits::lowest(); + ptr[kv_seq_len * i + j] = (j > i) * std::numeric_limits::lowest(); } } return attentionMask; diff --git a/transformers/llm/engine/src/minja/chat_template.cpp b/transformers/llm/engine/src/minja/chat_template.cpp new file mode 100644 index 00000000..28a3a5b0 --- /dev/null +++ b/transformers/llm/engine/src/minja/chat_template.cpp @@ -0,0 +1,746 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#ifdef LLM_USE_MINJA +#include "minja.hpp" +#include "chat_template.hpp" +namespace minja { + // Helper to convert Value to string + static std::string valueToString(const rapidjson::Value& val) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + val.Accept(writer); + return buffer.GetString(); + } + chat_template::chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) + : source_(source), bos_token_(bos_token), eos_token_(eos_token) + { + template_root_ = minja::parse(source_, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); +#define MINJA_ADD_TEST +#ifdef MINJA_ADD_TEST + auto contains = [](const std::string & haystack, const std::string & needle) { + return haystack.find(needle) != std::string::npos; + }; + + // This entire block needs to be refactored to use rapidjson. + // This is a significant change due to how objects and arrays are constructed. + // I will need a Document (with its allocator) for each JSON structure. + Document d_render_test; // Document for constructing test JSONs + auto& alloc = d_render_test.GetAllocator(); + + const std::string user_needle = ""; + const std::string sys_needle = ""; + + // const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}}; + rapidjson::Value dummy_str_user_msg(rapidjson::kObjectType); + dummy_str_user_msg.AddMember("role", "user", alloc); + dummy_str_user_msg.AddMember("content", rapidjson::StringRef(user_needle.c_str()), alloc); + // const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}}; + rapidjson::Value dummy_typed_user_msg(rapidjson::kObjectType); + dummy_typed_user_msg.AddMember("role", "user", alloc); + { + rapidjson::Value content_array(rapidjson::kArrayType); + rapidjson::Value content_item(rapidjson::kObjectType); + content_item.AddMember("type", "text", alloc); + content_item.AddMember("text", rapidjson::StringRef(user_needle.c_str()), alloc); + content_array.PushBack(content_item, alloc); + dummy_typed_user_msg.AddMember("content", content_array, alloc); + } + rapidjson::Value dummy_str_user_msg_copy1; + dummy_str_user_msg_copy1.CopyFrom(dummy_str_user_msg, alloc); + + rapidjson::Value messages_for_render1(rapidjson::kArrayType); + messages_for_render1.PushBack(dummy_str_user_msg_copy1, alloc); + rapidjson::Value no_tools(rapidjson::kArrayType); // Assuming empty array for no tools + + // For capability detection, polyfills are off, so copy is fine. + rapidjson::Value messages_typed_content_test1(rapidjson::kArrayType); + dummy_str_user_msg_copy1.CopyFrom(dummy_str_user_msg, alloc); + messages_typed_content_test1.PushBack(dummy_str_user_msg_copy1, alloc); + rapidjson::Value no_tools_copy1; no_tools_copy1.CopyFrom(no_tools, alloc); + + rapidjson::Value dummy_typed_user_msg_copy1; + dummy_typed_user_msg_copy1.CopyFrom(dummy_typed_user_msg, alloc); + rapidjson::Value messages_typed_content_test2(rapidjson::kArrayType); + messages_typed_content_test2.PushBack(dummy_typed_user_msg_copy1, alloc); + rapidjson::Value no_tools_copy2; no_tools_copy2.CopyFrom(no_tools, alloc); + + caps_.requires_typed_content = + !contains(try_raw_render(messages_typed_content_test1, no_tools_copy1, false, alloc), user_needle) && + contains(try_raw_render(messages_typed_content_test2, no_tools_copy2, false, alloc), user_needle); + + // const auto dummy_user_msg = caps_.requires_typed_content ? dummy_typed_user_msg : dummy_str_user_msg; + rapidjson::Value dummy_user_msg(rapidjson::kObjectType); + if (caps_.requires_typed_content) { + dummy_user_msg.CopyFrom(dummy_typed_user_msg, alloc); + } else { + dummy_user_msg.CopyFrom(dummy_str_user_msg, alloc); + } + + // const json needle_system_msg = { + // {"role", "system"}, + // {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)}, + // }; + rapidjson::Value needle_system_msg(rapidjson::kObjectType); + needle_system_msg.AddMember("role", "system", alloc); + if (caps_.requires_typed_content) { + rapidjson::Value content_array_sys(rapidjson::kArrayType); + rapidjson::Value content_item_sys(rapidjson::kObjectType); + content_item_sys.AddMember("type", "text", alloc); + content_item_sys.AddMember("text", rapidjson::StringRef(sys_needle.c_str()), alloc); + content_array_sys.PushBack(content_item_sys, alloc); + needle_system_msg.AddMember("content", content_array_sys, alloc); + } else { + needle_system_msg.AddMember("content", rapidjson::StringRef(sys_needle.c_str()), alloc); + } + + // caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle); + rapidjson::Value messages_for_sys_role_test(rapidjson::kArrayType); + rapidjson::Value needle_system_msg_copy; needle_system_msg_copy.CopyFrom(needle_system_msg, alloc); + rapidjson::Value dummy_user_msg_copy2; dummy_user_msg_copy2.CopyFrom(dummy_user_msg, alloc); + messages_for_sys_role_test.PushBack(needle_system_msg_copy, alloc); + messages_for_sys_role_test.PushBack(dummy_user_msg_copy2, alloc); + rapidjson::Value no_tools_copy3; no_tools_copy3.CopyFrom(no_tools, alloc); + caps_.supports_system_role = contains(try_raw_render(messages_for_sys_role_test, no_tools_copy3, false, alloc), sys_needle); + + // auto out = try_raw_render(json::array({dummy_user_msg}), json::array({...}), false); + rapidjson::Value messages_for_tools_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy3; dummy_user_msg_copy3.CopyFrom(dummy_user_msg, alloc); + messages_for_tools_test.PushBack(dummy_user_msg_copy3, alloc); + + rapidjson::Value tools_for_test(rapidjson::kArrayType); + rapidjson::Value tool_def(rapidjson::kObjectType); + tool_def.AddMember("name", "some_tool", alloc); + tool_def.AddMember("type", "function", alloc); + rapidjson::Value function_def(rapidjson::kObjectType); + function_def.AddMember("name", "some_tool", alloc); + function_def.AddMember("description", "Some tool.", alloc); + rapidjson::Value params_def(rapidjson::kObjectType); + params_def.AddMember("type", "object", alloc); + rapidjson::Value props_def(rapidjson::kObjectType); + rapidjson::Value arg_def(rapidjson::kObjectType); + arg_def.AddMember("type", "string", alloc); + arg_def.AddMember("description", "Some argument.", alloc); + props_def.AddMember("arg", arg_def, alloc); + params_def.AddMember("properties", props_def, alloc); + rapidjson::Value required_arr(rapidjson::kArrayType); + required_arr.PushBack("arg", alloc); + params_def.AddMember("required", required_arr, alloc); + function_def.AddMember("parameters", params_def, alloc); + tool_def.AddMember("function", function_def, alloc); + tools_for_test.PushBack(tool_def, alloc); + + std::string out_tools_test = try_raw_render(tools_for_test, tools_for_test, false, alloc); + caps_.supports_tools = contains(out_tools_test, "some_tool"); + + // auto make_tool_calls_msg = [&](const json & tool_calls) { ... } + auto make_tool_calls_msg_rj = [&](rapidjson::Value& tool_calls_val, rapidjson::Document::AllocatorType& allocator_func) { + rapidjson::Value msg(rapidjson::kObjectType); + msg.AddMember("role", "assistant", allocator_func); + msg.AddMember("content", rapidjson::Value(rapidjson::kNullType), allocator_func); + msg.AddMember("tool_calls", tool_calls_val, allocator_func); // tool_calls_val is already using alloc from caller + return msg; + }; + + // auto make_tool_call = [](const std::string & tool_name, const json & arguments) { ... } + auto make_tool_call_rj = [&](const char* tool_name_str, rapidjson::Value& arguments_val, rapidjson::Document::AllocatorType& allocator_func) { + rapidjson::Value tc(rapidjson::kObjectType); + tc.AddMember("id", "call_1___", allocator_func); + tc.AddMember("type", "function", allocator_func); + rapidjson::Value func(rapidjson::kObjectType); + func.AddMember("arguments", arguments_val, allocator_func); // arguments_val is already using alloc from caller + func.AddMember("name", rapidjson::StringRef(tool_name_str), allocator_func); + tc.AddMember("function", func, allocator_func); + return tc; + }; + + // const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}}; + rapidjson::Value dummy_args_obj_rj(rapidjson::kObjectType); + dummy_args_obj_rj.AddMember("argument_needle", "print('Hello, World!')", alloc); + + // Convert dummy_args_obj_rj to string for the first test + rapidjson::StringBuffer buffer_args_str; + rapidjson::Writer writer_args_str(buffer_args_str); + dummy_args_obj_rj.Accept(writer_args_str); + std::string dummy_args_obj_as_string = buffer_args_str.GetString(); + rapidjson::Value dummy_args_str_val(dummy_args_obj_as_string.c_str(), alloc); + + + // out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})) }), {}, false); + rapidjson::Value messages_for_tool_call_str_args_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy4; dummy_user_msg_copy4.CopyFrom(dummy_user_msg, alloc); + messages_for_tool_call_str_args_test.PushBack(dummy_user_msg_copy4, alloc); + rapidjson::Value tool_calls_array1(rapidjson::kArrayType); + rapidjson::Value tc1_args_str; tc1_args_str.CopyFrom(dummy_args_str_val, alloc); // Already a string value + std::string ipython = "ipython"; + tool_calls_array1.PushBack(make_tool_call_rj(ipython.c_str(), tc1_args_str, alloc), alloc); + rapidjson::Value tool_calls_msg1 = make_tool_calls_msg_rj(tool_calls_array1, alloc); + messages_for_tool_call_str_args_test.PushBack(tool_calls_msg1, alloc); + rapidjson::Value no_tools_copy4; no_tools_copy4.CopyFrom(no_tools, alloc); + std::string out_tool_call_str_args = try_raw_render(messages_for_tool_call_str_args_test, no_tools_copy4, false, alloc); + bool tool_call_renders_str_arguments = contains(out_tool_call_str_args, "\"argument_needle\":") || contains(out_tool_call_str_args, "'argument_needle':"); + + // out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})) }), {}, false); + rapidjson::Value messages_for_tool_call_obj_args_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy5; dummy_user_msg_copy5.CopyFrom(dummy_user_msg, alloc); + messages_for_tool_call_obj_args_test.PushBack(dummy_user_msg_copy5, alloc); + rapidjson::Value tool_calls_array2(rapidjson::kArrayType); + rapidjson::Value tc1_args_obj; tc1_args_obj.CopyFrom(dummy_args_obj_rj, alloc); + tool_calls_array2.PushBack(make_tool_call_rj("ipython", tc1_args_obj, alloc), alloc); + rapidjson::Value tool_calls_msg2 = make_tool_calls_msg_rj(tool_calls_array2, alloc); + messages_for_tool_call_obj_args_test.PushBack(tool_calls_msg2, alloc); + rapidjson::Value no_tools_copy5; no_tools_copy5.CopyFrom(no_tools, alloc); + std::string out_tool_call_obj_args = try_raw_render(messages_for_tool_call_obj_args_test, no_tools_copy5, false, alloc); + bool tool_call_renders_obj_arguments = contains(out_tool_call_obj_args, "\"argument_needle\":") || contains(out_tool_call_obj_args, "'argument_needle':"); + + caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; + caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; + + // auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); + rapidjson::Value messages_for_empty_content_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy6; dummy_user_msg_copy6.CopyFrom(dummy_user_msg, alloc); + messages_for_empty_content_test.PushBack(dummy_user_msg_copy6, alloc); + rapidjson::Value assistant_msg_empty_content(rapidjson::kObjectType); + assistant_msg_empty_content.AddMember("role", "assistant", alloc); + assistant_msg_empty_content.AddMember("content", "", alloc); + messages_for_empty_content_test.PushBack(assistant_msg_empty_content, alloc); + rapidjson::Value no_tools_copy6; no_tools_copy6.CopyFrom(no_tools, alloc); + std::string out_empty_content = try_raw_render(messages_for_empty_content_test, no_tools_copy6, false, alloc); + + // auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + rapidjson::Value messages_for_null_content_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy7; dummy_user_msg_copy7.CopyFrom(dummy_user_msg, alloc); + messages_for_null_content_test.PushBack(dummy_user_msg_copy7, alloc); + rapidjson::Value assistant_msg_null_content(rapidjson::kObjectType); + assistant_msg_null_content.AddMember("role", "assistant", alloc); + assistant_msg_null_content.AddMember("content", rapidjson::Value(rapidjson::kNullType), alloc); + messages_for_null_content_test.PushBack(assistant_msg_null_content, alloc); + rapidjson::Value no_tools_copy7; no_tools_copy7.CopyFrom(no_tools, alloc); + std::string out_null_content = try_raw_render(messages_for_null_content_test, no_tools_copy7, false, alloc); + caps_.requires_non_null_content = contains(out_empty_content, user_needle) && !contains(out_null_content, user_needle); + + + if (caps_.supports_tool_calls) { + // auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); + rapidjson::Value dummy_args_for_parallel_test; + if (caps_.requires_object_arguments) { + dummy_args_for_parallel_test.CopyFrom(dummy_args_obj_rj, alloc); + } else { + // This was already created: dummy_args_str_val (string version of dummy_args_obj_rj) + dummy_args_for_parallel_test.CopyFrom(dummy_args_str_val, alloc); + } + + // auto tc1 = make_tool_call("test_tool1", dummy_args); + // auto tc2 = make_tool_call("test_tool2", dummy_args); + rapidjson::Value dummy_args_tc1; dummy_args_tc1.CopyFrom(dummy_args_for_parallel_test, alloc); + rapidjson::Value tc1 = make_tool_call_rj("test_tool1", dummy_args_tc1, alloc); + rapidjson::Value dummy_args_tc2; dummy_args_tc2.CopyFrom(dummy_args_for_parallel_test, alloc); + rapidjson::Value tc2 = make_tool_call_rj("test_tool2", dummy_args_tc2, alloc); + + // auto out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1, tc2})) }), {}, false); + rapidjson::Value messages_for_parallel_calls_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy8; dummy_user_msg_copy8.CopyFrom(dummy_user_msg, alloc); + messages_for_parallel_calls_test.PushBack(dummy_user_msg_copy8, alloc); + rapidjson::Value tool_calls_array_parallel(rapidjson::kArrayType); + tool_calls_array_parallel.PushBack(tc1, alloc); // tc1, tc2 are already using alloc + tool_calls_array_parallel.PushBack(tc2, alloc); + rapidjson::Value tool_calls_msg_parallel = make_tool_calls_msg_rj(tool_calls_array_parallel, alloc); + messages_for_parallel_calls_test.PushBack(tool_calls_msg_parallel, alloc); + rapidjson::Value no_tools_copy8; no_tools_copy8.CopyFrom(no_tools, alloc); + std::string out_parallel_calls = try_raw_render(messages_for_parallel_calls_test, no_tools_copy8, false, alloc); + caps_.supports_parallel_tool_calls = contains(out_parallel_calls, "test_tool1") && contains(out_parallel_calls, "test_tool2"); + + // Need to re-create tc1 as it was moved into tool_calls_array_parallel + rapidjson::Value dummy_args_tc1_resp; dummy_args_tc1_resp.CopyFrom(dummy_args_for_parallel_test, alloc); + rapidjson::Value tc1_resp = make_tool_call_rj("test_tool1", dummy_args_tc1_resp, alloc); + + // out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1})), { ...tool response... } }), {}, false); + rapidjson::Value messages_for_tool_response_test(rapidjson::kArrayType); + rapidjson::Value dummy_user_msg_copy9; dummy_user_msg_copy9.CopyFrom(dummy_user_msg, alloc); + messages_for_tool_response_test.PushBack(dummy_user_msg_copy9, alloc); + rapidjson::Value tool_calls_array_resp(rapidjson::kArrayType); + tool_calls_array_resp.PushBack(tc1_resp, alloc); + rapidjson::Value tool_calls_msg_resp = make_tool_calls_msg_rj(tool_calls_array_resp, alloc); + messages_for_tool_response_test.PushBack(tool_calls_msg_resp, alloc); + rapidjson::Value tool_response_msg(rapidjson::kObjectType); + tool_response_msg.AddMember("role", "tool", alloc); + tool_response_msg.AddMember("name", "test_tool1", alloc); + tool_response_msg.AddMember("content", "Some response!", alloc); + tool_response_msg.AddMember("tool_call_id", "call_911_", alloc); + messages_for_tool_response_test.PushBack(tool_response_msg, alloc); + rapidjson::Value no_tools_copy9; no_tools_copy9.CopyFrom(no_tools, alloc); + std::string out_tool_response = try_raw_render(messages_for_tool_response_test, no_tools_copy9, false, alloc); + caps_.supports_tool_responses = contains(out_tool_response, "Some response!"); + caps_.supports_tool_call_id = contains(out_tool_response, "call_911_"); + } + + if (!caps_.supports_tools) { + // const json user_msg { {"role", "user"}, {"content", "Hey"} }; + rapidjson::Value user_msg_infer(rapidjson::kObjectType); + user_msg_infer.AddMember("role", "user", alloc); + user_msg_infer.AddMember("content", "Hey", alloc); + + // const json args { {"arg1", "some_value"} }; + rapidjson::Value args_infer(rapidjson::kObjectType); + args_infer.AddMember("arg1", "some_value", alloc); + + // const json tool_call_msg { ... } + rapidjson::Value tool_call_msg_infer(rapidjson::kObjectType); + tool_call_msg_infer.AddMember("role", "assistant", alloc); + tool_call_msg_infer.AddMember("content", rapidjson::Value(rapidjson::kNullType), alloc); + rapidjson::Value tool_calls_array_infer(rapidjson::kArrayType); + rapidjson::Value tool_call_item_infer(rapidjson::kObjectType); + tool_call_item_infer.AddMember("id", "call_1___", alloc); + tool_call_item_infer.AddMember("type", "function", alloc); + rapidjson::Value function_item_infer(rapidjson::kObjectType); + function_item_infer.AddMember("name", "tool_name", alloc); + + rapidjson::Value arguments_infer; + if (caps_.requires_object_arguments) { + arguments_infer.CopyFrom(args_infer, alloc); + } else { + // This requires minja::Value::dump which itself uses nlohmann::json. + // This part needs a temporary nlohmann::json to dump, or reimplement dump logic for rapidjson. + // For now, let's assume minja::Value can give us a string that rapidjson can parse, + // or we construct the string directly. + // minja::Value(args).dump(-1, /* to_json= */ true) + // This is a major dependency. For now, I'll create a simple string version. + rapidjson::StringBuffer buffer_args_infer_str; + rapidjson::Writer writer_args_infer_str(buffer_args_infer_str); + args_infer.Accept(writer_args_infer_str); + arguments_infer.SetString(buffer_args_infer_str.GetString(), alloc); + } + function_item_infer.AddMember("arguments", arguments_infer, alloc); + tool_call_item_infer.AddMember("function", function_item_infer, alloc); + tool_calls_array_infer.PushBack(tool_call_item_infer, alloc); + tool_call_msg_infer.AddMember("tool_calls", tool_calls_array_infer, alloc); + + std::string prefix_str, full_str; + { + chat_template_inputs inputs_prefix; + inputs_prefix.allocator_for_inputs = &alloc; + inputs_prefix.messages.SetArray(); + rapidjson::Value user_msg_infer_copy1; user_msg_infer_copy1.CopyFrom(user_msg_infer, alloc); + inputs_prefix.messages.PushBack(user_msg_infer_copy1, alloc); + inputs_prefix.add_generation_prompt = true; + // inputs.tools is already kNullType by default in chat_template_inputs constructor + prefix_str = apply(inputs_prefix); + } + { + chat_template_inputs inputs_full; + inputs_full.allocator_for_inputs = &alloc; + inputs_full.messages.SetArray(); + rapidjson::Value user_msg_infer_copy2; user_msg_infer_copy2.CopyFrom(user_msg_infer, alloc); + inputs_full.messages.PushBack(user_msg_infer_copy2, alloc); + rapidjson::Value tool_call_msg_infer_copy; tool_call_msg_infer_copy.CopyFrom(tool_call_msg_infer, alloc); + inputs_full.messages.PushBack(tool_call_msg_infer_copy, alloc); + inputs_full.add_generation_prompt = false; + // inputs.tools is already kNullType by default + full_str = apply(inputs_full); + } + // ... rest of the logic for tool_call_example_ using prefix_str and full_str + // This part seems okay to remain as string manipulation + auto eos_pos_last = full_str.rfind(eos_token_); + if (eos_pos_last == prefix_str.size() - eos_token_.size() || + (full_str[full_str.size() - 1] == '\n' && (eos_pos_last == full_str.size() - eos_token_.size() - 1))) { + full_str = full_str.substr(0, eos_pos_last); + } + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix_str.size() && i < full_str.size(); ++i) { + if (prefix_str[i] != full_str[i]) { + break; + } + if (prefix_str[i] == '<') { + continue; + } + common_prefix_length = i + 1; + } + auto example = full_str.substr(common_prefix_length); + if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } else { + tool_call_example_ = example; + } + } + // Ensure d_render_test is cleared if it were a member, but it's local. +#endif + } + std::string chat_template::try_raw_render( + rapidjson::Value& messages, // Modifying to pass by ref as it might be changed by polyfills later + rapidjson::Value& tools, // Modifying to pass by ref + bool add_generation_prompt, + rapidjson::Document::AllocatorType& allocator, // Added allocator + rapidjson::Value extra_context) const // Default to null + { + chat_template_inputs inputs; + // Important: When assigning Value, if it's from another Document or a temporary, + // it needs to be deep copied using the allocator of the target Document/Value. + // For try_raw_render, we assume messages, tools, extra_context are already managed + // or will be properly constructed with an allocator. + // Here, we're creating new Value objects for the inputs struct, so they need an allocator + // if they are to be populated. However, inputs here is temporary. + // The original nlohmann version copied, rapidjson Value assignment is a shallow copy. + // This needs careful handling. For now, let's assume the caller manages lifetime. + // This is tricky because the Value objects in chat_template_inputs need an allocator. + // Let's try to pass the allocator to inputs. + inputs.allocator_for_inputs = &allocator; + inputs.messages.CopyFrom(messages, allocator); + inputs.tools.CopyFrom(tools, allocator); + inputs.add_generation_prompt = add_generation_prompt; + if (!extra_context.IsNull()) { + inputs.extra_context.CopyFrom(extra_context, allocator); + } else { + inputs.extra_context.SetObject(); // Initialize as empty object if default + } + // Use fixed date for tests + inputs.now = std::chrono::system_clock::from_time_t(0); + + chat_template_options opts; + opts.apply_polyfills = false; + + auto prompt = apply(inputs, opts); + // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); + return prompt; + } + std::string chat_template::apply( + chat_template_inputs & inputs, + const chat_template_options & opts) const { + AUTOTIME; + // Create a working document for this apply call. + // All new JSON Values created within this scope should use its allocator. + Document working_doc; + rapidjson::Document::AllocatorType& allocator = working_doc.GetAllocator(); + + rapidjson::Value actual_messages(rapidjson::kArrayType); // Uses working_doc's allocator by default if created here + + auto has_tools = inputs.tools.IsArray() && !inputs.tools.Empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + + if (inputs.messages.IsArray()) { + for (const auto & message_val : inputs.messages.GetArray()) { + if (message_val.IsObject()) { + if (message_val.HasMember("tool_calls") && !message_val["tool_calls"].IsNull()) { + has_tool_calls = true; + } + if (message_val.HasMember("role") && message_val["role"].IsString() && + strcmp(message_val["role"].GetString(), "tool") == 0) { + has_tool_responses = true; + } + if (message_val.HasMember("content") && message_val["content"].IsString()) { + has_string_content = true; + } + } + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + + auto needs_polyfills = opts.apply_polyfills && (false + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content + ); + + if (needs_polyfills) { + // actual_messages is already an empty array, using allocator + + auto add_message = [&](const rapidjson::Value & msg_const) { + rapidjson::Value msg; + msg.CopyFrom(msg_const, allocator); // Ensure it uses the current doc's allocator + + if (polyfill_typed_content && msg.IsObject() && msg.HasMember("content") && + !msg["content"].IsNull() && msg["content"].IsString()) { + + rapidjson::Value new_msg(rapidjson::kObjectType); + new_msg.AddMember("role", rapidjson::Value(msg["role"], allocator), allocator); // copy role + + rapidjson::Value content_array_typed(rapidjson::kArrayType); + rapidjson::Value content_item_typed(rapidjson::kObjectType); + content_item_typed.AddMember("type", "text", allocator); + // Need to copy the string content for "text" + rapidjson::Value text_val(msg["content"].GetString(), allocator); + content_item_typed.AddMember("text", text_val, allocator); + content_array_typed.PushBack(content_item_typed, allocator); + new_msg.AddMember("content", content_array_typed, allocator); + actual_messages.PushBack(new_msg, allocator); + } else { + actual_messages.PushBack(msg, allocator); // msg already copied with allocator + } + }; + + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + rapidjson::Value sys_as_user_msg(rapidjson::kObjectType); + sys_as_user_msg.AddMember("role", "user", allocator); + sys_as_user_msg.AddMember("content", rapidjson::StringRef(pending_system.c_str()), allocator); + add_message(sys_as_user_msg); // add_message will handle typed content if needed + pending_system.clear(); + } + }; + + rapidjson::Value adjusted_messages_val(rapidjson::kArrayType); + if (polyfill_tools) { + // Convert inputs.tools to string for the system prompt + rapidjson::StringBuffer tools_buffer; + rapidjson::PrettyWriter tools_writer(tools_buffer); // Pretty for readability + tools_writer.SetIndent(' ', 2); + inputs.tools.Accept(tools_writer); + std::string tools_str_prompt = tools_buffer.GetString(); + + std::string system_prompt_str = + "You can call any of the following tools to satisfy the user's requests: " + tools_str_prompt + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + + // add_system returns a new Value, ensure it uses 'allocator' + rapidjson::Value messages_copy_for_add_system; + messages_copy_for_add_system.CopyFrom(inputs.messages, allocator); + adjusted_messages_val = add_system(messages_copy_for_add_system, system_prompt_str, allocator); + } else { + adjusted_messages_val.CopyFrom(inputs.messages, allocator); + } + + if (adjusted_messages_val.IsArray()){ + for (auto & message_val_mut : adjusted_messages_val.GetArray()) { // Iterate by mutable ref + // message_ is already using 'allocator' as it's part of adjusted_messages_val + rapidjson::Value message; // Create a mutable copy for this iteration + message.CopyFrom(message_val_mut, allocator); + + + if (!message.IsObject() || !message.HasMember("role") || !message.HasMember("content")) { + // MNN_ERROR replacement: + fprintf(stderr, "message must have 'role' and 'content' fields: %s\n", valueToString(message).c_str()); + // Potentially skip this message or handle error + continue; + } + const char* role_cstr = message["role"].GetString(); + std::string role = role_cstr; + + if (message.HasMember("tool_calls")) { + if (polyfill_object_arguments || polyfill_tool_calls) { + if (message["tool_calls"].IsArray()) { + for (auto & tool_call_val : message["tool_calls"].GetArray()) { + if (tool_call_val.IsObject() && tool_call_val.HasMember("type") && tool_call_val["type"] == "function") { + if (tool_call_val.HasMember("function") && tool_call_val["function"].IsObject()) { + auto& function_val = tool_call_val["function"]; + if (function_val.HasMember("arguments") && function_val["arguments"].IsString()) { + std::string args_str = function_val["arguments"].GetString(); + Document args_doc; + if (!args_doc.Parse(args_str.c_str()).HasParseError()) { + // Replace the string arguments with the parsed Value object + // The new Value must use 'allocator' + rapidjson::Value new_args_val; + new_args_val.CopyFrom(args_doc, allocator); + function_val["arguments"].Swap(new_args_val); // Swap to avoid copy if possible + } + } + } + } + } + } + } + if (polyfill_tool_calls) { + rapidjson::Value content_val; content_val.CopyFrom(message["content"], allocator); // Keep original content if any + rapidjson::Value tool_calls_payload(rapidjson::kArrayType); + if (message["tool_calls"].IsArray()) { + for (const auto & tool_call_val_const : message["tool_calls"].GetArray()) { + if (tool_call_val_const.IsObject() && tool_call_val_const.HasMember("type") && tool_call_val_const["type"] == "function") { + const auto& function_val_const = tool_call_val_const["function"]; + rapidjson::Value tc_item(rapidjson::kObjectType); + tc_item.AddMember("name", rapidjson::Value(function_val_const["name"], allocator), allocator); + // Arguments should already be objects if polyfill_object_arguments ran + tc_item.AddMember("arguments", rapidjson::Value(function_val_const["arguments"], allocator), allocator); + if (tool_call_val_const.HasMember("id")) { + tc_item.AddMember("id", rapidjson::Value(tool_call_val_const["id"], allocator), allocator); + } + tool_calls_payload.PushBack(tc_item, allocator); + } + } + } + rapidjson::Value obj_for_content(rapidjson::kObjectType); + obj_for_content.AddMember("tool_calls", tool_calls_payload, allocator); + if (!content_val.IsNull() && !(content_val.IsString() && strlen(content_val.GetString()) == 0)) { + obj_for_content.AddMember("content", content_val, allocator); + } + + // Serialize obj_for_content to string for message["content"] + rapidjson::StringBuffer s_buffer; + rapidjson::PrettyWriter writer_obj(s_buffer); + writer_obj.SetIndent(' ', 2); + obj_for_content.Accept(writer_obj); + message["content"].SetString(s_buffer.GetString(), allocator); + message.RemoveMember("tool_calls"); + } + } + if (polyfill_tool_responses && role == "tool") { + message["role"].SetString("user", allocator); // Change role to user + rapidjson::Value tool_response_obj(rapidjson::kObjectType); + rapidjson::Value tool_response_inner_obj(rapidjson::kObjectType); + + if (message.HasMember("name")) { + tool_response_inner_obj.AddMember("tool", rapidjson::Value(message["name"], allocator), allocator); + } + // message["content"] is guaranteed to exist by check above + tool_response_inner_obj.AddMember("content", rapidjson::Value(message["content"], allocator), allocator); + if (message.HasMember("tool_call_id")) { + tool_response_inner_obj.AddMember("tool_call_id", rapidjson::Value(message["tool_call_id"], allocator), allocator); + } + tool_response_obj.AddMember("tool_response", tool_response_inner_obj, allocator); + + // Serialize tool_response_obj to string for message["content"] + rapidjson::StringBuffer s_buffer_resp; + rapidjson::PrettyWriter writer_resp(s_buffer_resp); + writer_resp.SetIndent(' ',2); + tool_response_obj.Accept(writer_resp); + message["content"].SetString(s_buffer_resp.GetString(), allocator); + + if (message.HasMember("name")) message.RemoveMember("name"); + if (message.HasMember("tool_call_id")) message.RemoveMember("tool_call_id"); // if it was there + } + + if (!message["content"].IsNull() && polyfill_system_role) { + // Assuming content is string after previous polyfills or by its nature + std::string content_str; + if (message["content"].IsString()){ + content_str = message["content"].GetString(); + } else { + // If content is not string (e.g. array for typed content), it needs to be stringified for pending_system + // This case should be handled by typed_content polyfill first if active + // For simplicity, if it's not string here, we might skip or stringify it + rapidjson::StringBuffer temp_s_buffer; + rapidjson::Writer temp_writer(temp_s_buffer); + message["content"].Accept(temp_writer); + content_str = temp_s_buffer.GetString(); + } + + if (role == "system") { + if (!pending_system.empty()) pending_system += "\n"; + pending_system += content_str; + // This message is consumed, skip adding it directly + // A continue here would skip the 'add_message(message)' below for system messages + // which is the desired behavior. + // However, the original code structure adds the modified message (if not system) + // or flushes system messages. + // Let's ensure this message isn't added by 'add_message' if it's system. + // The flush_sys() and add_message(message) logic outside the loop handles it. + // So, if role is system, we just update pending_system and the message itself is not added. + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + std::string new_content = pending_system + (content_str.empty() ? "" : "\n" + content_str); + message["content"].SetString(new_content.c_str(), allocator); + pending_system.clear(); + } + } else { // assistant, tool (already transformed to user) + flush_sys(); + } + } + } + add_message(message); // add_message handles copying to actual_messages with allocator + } + } + flush_sys(); + } else { // no polyfills needed + actual_messages.CopyFrom(inputs.messages, allocator); + } + + auto context = minja::Context::make(nullptr); // nlohmann::json() equivalent for context data + // The make function needs to be adapted for rapidjson::Value + // For now, creating an empty object for context data. + rapidjson::Value context_data_val(rapidjson::kObjectType); + context_data_val.AddMember("messages", actual_messages, allocator); // actual_messages already uses allocator + context_data_val.AddMember("add_generation_prompt", inputs.add_generation_prompt, allocator); + + + // Convert context_data_val to nlohmann::json for minja::Context::make + // This is a temporary bridge. minja::Context itself needs to be updated for rapidjson. + // This is a critical dependency. + + context = minja::Context::make(minja::Value(context_data_val)); + + context->set("bos_token", opts.use_bos_token ? bos_token_ : ""); + context->set("eos_token", opts.use_eos_token ? eos_token_ : ""); + if (opts.define_strftime_now) { + auto time_now_capture = inputs.now; // capture for lambda + context->set("strftime_now", minja::Value::callable([time_now_capture](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); + + auto time_point = std::chrono::system_clock::to_time_t(time_now_capture); + auto local_time = *std::localtime(&time_point); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); + } + + if (!inputs.tools.IsNull()) { + context->set("tools", minja::Value(inputs.tools)); + } + if (!inputs.extra_context.IsNull() && inputs.extra_context.IsObject()) { + for (auto & kv : inputs.extra_context.GetObject()) { + context->set(kv.name.GetString(), minja::Value(kv.value)); + } + } + + auto ret = template_root_->render(context); + return ret; + } + rapidjson::Value chat_template::add_system( + const rapidjson::Value & messages_const, // input messages (const ref) + const std::string & system_prompt, + rapidjson::Document::AllocatorType& allocator) { + rapidjson::Value messages_with_system(rapidjson::kArrayType); + messages_with_system.CopyFrom(messages_const, allocator); // Deep copy to make it modifiable + + if (!messages_with_system.Empty() && messages_with_system[0].IsObject() && + messages_with_system[0].HasMember("role") && messages_with_system[0]["role"] == "system") { + + std::string existing_system_content_str; + if (messages_with_system[0].HasMember("content") && messages_with_system[0]["content"].IsString()) { + existing_system_content_str = messages_with_system[0]["content"].GetString(); + } + + std::string new_content_str = existing_system_content_str + "\n\n" + system_prompt; + messages_with_system[0]["content"].SetString(new_content_str.c_str(), allocator); + + } else { + rapidjson::Value new_system_msg(rapidjson::kObjectType); + new_system_msg.AddMember("role", "system", allocator); + new_system_msg.AddMember("content", rapidjson::StringRef(system_prompt.c_str()), allocator); + + // Insert at the beginning + rapidjson::Value temp_array(rapidjson::kArrayType); + temp_array.PushBack(new_system_msg, allocator); + for (auto& el : messages_with_system.GetArray()) { + rapidjson::Value el_copy; + el_copy.CopyFrom(el, allocator); + temp_array.PushBack(el_copy, allocator); + } + messages_with_system.Swap(temp_array); + } + return messages_with_system; // This Value is allocated with 'allocator' + } +}; +#endif diff --git a/transformers/llm/engine/src/minja/chat_template.hpp b/transformers/llm/engine/src/minja/chat_template.hpp new file mode 100644 index 00000000..14367583 --- /dev/null +++ b/transformers/llm/engine/src/minja/chat_template.hpp @@ -0,0 +1,112 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "rapidjson/document.h" + +// Forward declaration for Value used in Minja +namespace minja { +class Value; +class TemplateNode; +} + +using Document = rapidjson::Document; +// Note: rapidjson::Value is the type for all JSON values (objects, arrays, strings, numbers, booleans, null). +// rapidjson::Document inherits from rapidjson::Value and holds the memory allocations for the DOM. +// We will use rapidjson::Value where nlohmann::json was used for individual values, +// and rapidjson::Document where a new JSON structure was being parsed or built. + +namespace minja { + +struct chat_template_caps { + bool supports_tools = false; + bool supports_tool_calls = false; + bool supports_tool_responses = false; + bool supports_system_role = false; + bool supports_parallel_tool_calls = false; + bool supports_tool_call_id = false; + // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. + bool requires_object_arguments = false; + // CohereForAI/c4ai-command-r-plus simple variant + bool requires_non_null_content = false; + // MiniMaxAI/MiniMax-Text-01 special + bool requires_typed_content = false; +}; + +struct chat_template_inputs { + rapidjson::Document messages; // Should be an array + rapidjson::Document tools; // Should be an array or null + bool add_generation_prompt = true; + rapidjson::Document extra_context; // Should be an object or null + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + rapidjson::Document::AllocatorType* allocator_for_inputs = nullptr; // To be set when creating inputs + + // Default constructor to initialize Value members + chat_template_inputs() : messages(rapidjson::kArrayType), tools(rapidjson::kNullType), extra_context(rapidjson::kNullType) {} +}; + +struct chat_template_options { + bool apply_polyfills = true; + bool use_bos_token = true; + bool use_eos_token = true; + bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; +}; + +class chat_template { + +private: + chat_template_caps caps_; + std::string source_; + std::string bos_token_; + std::string eos_token_; + std::shared_ptr template_root_; + std::string tool_call_example_; + + std::string try_raw_render( + rapidjson::Value& messages, // Modifying to pass by ref as it might be changed by polyfills later + rapidjson::Value& tools, // Modifying to pass by ref + bool add_generation_prompt, + rapidjson::Document::AllocatorType& allocator, // Added allocator + rapidjson::Value extra_context = rapidjson::Value(rapidjson::kNullType)) const; // Default to null +public: + MNN_PUBLIC chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token); + + const std::string & source() const { return source_; } + const std::string & bos_token() const { return bos_token_; } + const std::string & eos_token() const { return eos_token_; } + const chat_template_caps & original_caps() const { return caps_; } + + + MNN_PUBLIC std::string apply( + chat_template_inputs & inputs, + const chat_template_options & opts = chat_template_options()) const; + + static rapidjson::Value add_system( + const rapidjson::Value & messages_const, // input messages (const ref) + const std::string & system_prompt, + rapidjson::Document::AllocatorType& allocator); +}; +}; diff --git a/transformers/llm/engine/src/minja/minja.cpp b/transformers/llm/engine/src/minja/minja.cpp new file mode 100644 index 00000000..9d908ba2 --- /dev/null +++ b/transformers/llm/engine/src/minja/minja.cpp @@ -0,0 +1,1512 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#ifdef LLM_USE_MINJA +#include "minja.hpp" +namespace minja { +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + case Type::Filter: return "filter"; + case Type::EndFilter: return "endfilter"; + case Type::Generation: return "generation"; + case Type::EndGeneration: return "endgeneration"; + case Type::Break: return "break"; + case Type::Continue: return "continue"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, loc, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, loc, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, loc, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, loc, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, loc, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, loc, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, loc, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, loc, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, loc, pre, post), filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, loc, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + std::shared_ptr && c, bool r) + : TemplateToken(Type::For, loc, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, loc, pre, post) {} +}; + +struct GenerationTemplateToken : public TemplateToken { + GenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, loc, pre, post) {} +}; + +struct EndGenerationTemplateToken : public TemplateToken { + EndGenerationTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, loc, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, loc, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, loc, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, loc, pre, post), text(t) {} +}; + + + +struct LoopControlTemplateToken : public TemplateToken { + LoopControlType control_type; + LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} +}; + class Parser { + private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) _printlog("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::shared_ptr parseString() { + auto doParse = [&](char quote) -> std::shared_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + std::shared_ptr res(new std::string); + *res = result; + return res; + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) { + _printlog("Multiple decimal points"); + return json(); + } + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) { + _printlog("Multiple exponents"); + return json(); + } + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + std::string str(start, it); + if (hasExponent || hasDecimal) { + double v = std::stof(str); + return json(v); + } + int64_t v = std::stoi(str); + return json(v); + } + + /** integer, float, bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return std::make_shared(true); + if (token == "false" || token == "False") return std::make_shared(false); + if (token == "None") return std::make_shared(nullptr); + _printlog("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return std::make_shared(number); + + it = start; + return nullptr; + } + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto cepair = parseIfExpression(); + auto condition = cepair.first; + auto else_expr = cepair.second; + return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::shared_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) _printlog("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) _printlog("Expected 'else' expression"); + } + return std::make_pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) _printlog("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) _printlog("Expected right side of 'or' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) _printlog("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) _printlog("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) _printlog("Expected right side of 'and' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) _printlog("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) _printlog("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) _printlog("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else _printlog("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) _printlog("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) _printlog("Expected expression in call args"); + if (expr->mType == Expression::Type_Variable) { + auto ident = (VariableExpr*)(expr.get()); + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) _printlog("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + _printlog("Expected closing parenthesis in call args"); + } + return result; + } + } + _printlog("Expected closing parenthesis in call args"); + return result; + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) _printlog("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) _printlog("Expected expression in call args"); + + if (expr->mType == Expression::Type_Variable) { + auto ident = (VariableExpr*)(expr.get()); + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) _printlog("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + _printlog("Expected closing parenthesis in call args"); + } + return result; + } + } + _printlog("Expected closing parenthesis in call args"); + return result; + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) _printlog("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) _printlog("Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) _printlog("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) _printlog("Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) _printlog("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) _printlog("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) _printlog("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) _printlog("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (expr->mType == Expression::Type_Filter) { + auto filter = (FilterExpr*)(expr.get()); + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { + return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) _printlog("Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) return expr; + if (!expr) { + _printlog("Expected expr of 'expansion' expression"); + return nullptr; + } + return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + _printlog("Expected value expression"); + return nullptr; + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + auto slice_loc = get_location(); + std::shared_ptr start, end, step; + bool c1 = false, c2 = false; + + if (!peekSymbols({ ":" })) { + start = parseExpression(); + } + + if (!consumeToken(":").empty()) { + c1 = true; + if (!peekSymbols({ ":", "]" })) { + end = parseExpression(); + } + if (!consumeToken(":").empty()) { + c2 = true; + if (!peekSymbols({ "]" })) { + step = parseExpression(); + } + } + } + + if ((c1 || c2) && (start || end || step)) { + index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); + } else { + index = std::move(start); + } + if (!index) { + MNN_ERROR("Empty index in subscript"); + } + if (consumeToken("]").empty()) { + MNN_ERROR("Expected closing bracket in subscript"); + } + + value = std::make_shared(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) _printlog("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = std::make_shared(identifier->location, Value(identifier->get_name())); + value = std::make_shared(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) _printlog("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) _printlog("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) _printlog("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + _printlog("Expected closing parenthesis"); + return nullptr; + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) _printlog("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) _printlog("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + _printlog("Expected comma or closing bracket in array"); + } + } + _printlog("Expected closing bracket"); + return nullptr; + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::shared_ptr>> elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) _printlog("Expected key in dictionary"); + if (consumeToken(":").empty()) _printlog("Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) _printlog("Expected value in dictionary"); + elements.emplace_back(std::make_pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + _printlog("Expected comma or closing brace in dictionary"); + } + } + _printlog("Expected closing brace"); + return nullptr; + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) _printlog("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::string unexpected(const TemplateToken & token) const { + return std::string("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::string unterminated(const TemplateToken & token) const { + return std::string("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); + static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); + static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); + static std::regex block_close_regex(R"(\s*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + std::smatch match; + + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_shared(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + _printlog("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) _printlog("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) _printlog("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) _printlog("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) _printlog("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) _printlog("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) _printlog("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "generation") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "endgeneration") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) _printlog("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) _printlog("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) _printlog("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) _printlog("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(macroname), std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if (!filter) _printlog("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space)); + } else if (keyword == "break" || keyword == "continue") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_shared(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue)); + } else { + _printlog("Unexpected block: " + keyword); + } + } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + _printlog("Internal error: Expected a comment"); + _printlog("Missing end of comment tag"); + } + auto text_end = it + match.position(); + text = std::string(it, text_end); + it = text_end; + tokens.push_back(std::make_shared(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + text = std::string(it, end); + it = end; + tokens.push_back(std::make_shared(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } + } + return tokens; + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (token->type == TemplateToken::Type::If) { + auto if_token = (IfTemplateToken*)(token.get()); + std::vector, std::shared_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = (ElifTemplateToken*)((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + children.emplace_back(std::make_shared(token->location, std::move(cascade))); + } else if (token->type == TemplateToken::Type::For) { + auto for_token = (ForTemplateToken*)(token.get()); + auto body = parseTemplate(begin, it, end); + auto else_body = std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if(token->type == TemplateToken::Type::Generation) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). + children.emplace_back(std::move(body)); + } else if(token->type == TemplateToken::Type::Text) { + auto text_token = (TextTemplateToken*)(token.get()); + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"(\s+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + auto i = text.size(); + while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; + if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { + text.resize(i); + } + } + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^\s+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && (*(it - 2))->type != TemplateToken::Type::Expression) { + if (!text.empty() && text[0] == '\n') { + text.erase(0, 1); + } + } + if (it == end && !options.keep_trailing_newline) { + auto i = text.size(); + if (i > 0 && text[i - 1] == '\n') { + i--; + if (i > 0 && text[i - 1] == '\r') i--; + text.resize(i); + } + } + children.emplace_back(std::make_shared(token->location, text)); + } else if(token->type == TemplateToken::Type::Expression) { + auto expr_token = (ExpressionTemplateToken*)(token.get()); + children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); + } else if(token->type == TemplateToken::Type::Set) { + auto set_token = (SetTemplateToken*)(token.get()); + if (set_token->value) { + children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + if (!set_token->ns.empty()) _printlog("Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) _printlog("Structural assignment not supported in set with template value"); + auto & name = set_token->var_names[0]; + children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); + } + } else if(token->type == TemplateToken::Type::Macro) { + auto macro_token = (MacroTemplateToken*)(token.get()); + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if(token->type == TemplateToken::Type::Filter) { + auto filter_token = (FilterTemplateToken*)(token.get()); + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + MNN_ERROR("%s\n", unterminated(**start).c_str()); + } + children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); + } else if(token->type == TemplateToken::Type::Comment) { + // Ignore comments + } else if(token->type == TemplateToken::Type::Break) { + auto ctrl_token = (LoopControlTemplateToken*)(token.get()); + children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); + } else { + bool needBreak = false; + switch (token->type) { + case TemplateToken::Type::EndSet: + case TemplateToken::Type::EndFor: + case TemplateToken::Type::EndMacro: + case TemplateToken::Type::EndFilter: + case TemplateToken::Type::EndIf: + case TemplateToken::Type::Else: + case TemplateToken::Type::Elif: + case TemplateToken::Type::EndGeneration: + it--; + needBreak = true; + break; + default: + MNN_ERROR("%s\n", unexpected(**(it-1)).c_str()); + } + if (needBreak) { + break; + } + } + } + if (fully && it != end) { + MNN_ERROR("%s\n", unexpected(**it).c_str()); + } + if (children.empty()) { + return std::make_shared(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return std::make_shared(children[0]->location(), std::move(children)); + } + } + + public: + + static std::shared_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(normalize_newlines(template_str)), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* fully= */ true); + } + }; + std::shared_ptr parse(const std::string& template_str, const Options & options) { + return Parser::parse(template_str, options); + } + + std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + // globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + // _printlog(args.at("message").get()); + // })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* to_json= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (obj.is_string()) { + rapidjson::Document doc; + doc.Parse(obj.get().c_str()); + for (auto& kv : doc.GetObject()) { + items.push_back(Value::array({kv.name, kv.value})); + } + } else if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) _printlog("object is not a list"); + if (items.empty()) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + auto char_transform_function = [](const std::string & name, const std::function & fn) { + return simple_function(name, { "text" }, [=](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), fn); + return Value(res); + }); + }; + globals.set("lower", char_transform_function("lower", ::tolower)); + globals.set("upper", char_transform_function("upper", ::toupper)); + globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto & value = args.args[0]; + auto & default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value; + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) _printlog("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) _printlog("object is not iterable: " + items.dump()); + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) _printlog("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); + for (auto & iter : args.kwargs) { + auto& name = iter.first; + auto& value = iter.second; + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) _printlog("object is not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) _printlog("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + auto select_or_reject = [make_filter](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) { + return Value::array(); + } + if (!items.is_array()) { + _printlog("object is not iterable: " + items.dump()); + } + + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) { + _printlog("Undefined filter: " + args.args[1].dump()); + } + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (pred_res.to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } + return res; + }); + }; + globals.set("select", select_or_reject(/* is_select= */ true)); + globals.set("reject", select_or_reject(/* is_select= */ false)); + globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) _printlog("Undefined filter: " + args.args[1].dump()); + ArgumentsValue filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + _printlog("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent(args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) is_first = false; + else out += "\n"; + if (needs_indent) out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') out += "\n"; + return out; + })); + auto select_or_reject_attr = [](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) _printlog("object is not iterable: " + items.dump()); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) _printlog("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + }); + }; + globals.set("selectattr", select_or_reject_attr(/* is_select= */ true)); + globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false)); + globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & iter : args.kwargs) { + auto& name = iter.first; + auto& value = iter.second; + size_t i; + if (name == "start") { + i = 0; + } else if (name == "end") { + i = 1; + } else if (name == "step") { + i = 2; + } else { + _printlog("Unknown argument " + name + " for function range"); + } + + if (param_set[i]) { + _printlog("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + _printlog("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); + } + + +}; + +#endif diff --git a/transformers/llm/engine/src/minja/minja.hpp b/transformers/llm/engine/src/minja/minja.hpp new file mode 100644 index 00000000..764b9808 --- /dev/null +++ b/transformers/llm/engine/src/minja/minja.hpp @@ -0,0 +1,1670 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "rapidjson/document.h" +#include "rapidjson/prettywriter.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/error/en.h" // For GetParseError_En + +#include +//#define MNN_OPEN_TIME_TRACE +#include + +static void _printlog(const std::string& i) { + MNN_PRINT("%s\n", i.c_str()); +} + + +namespace minja { +enum ObjectType { + JSON_NULL = 0, + JSON_INT64 = 1, + JSON_DOUBLE = 2, + JSON_STRING = 3, + JSON_BOOL = 4, +}; +class json { +public: + ObjectType mType; + std::string mString; + int64_t mInt; + double mDouble; + bool mBool; + json() { + mType = JSON_NULL; + } + json(bool v) { + mBool = v; + mType = JSON_BOOL; + } + json(int64_t v) { + mInt = v; + mType = JSON_INT64; + } + json(double v) { + mDouble = v; + mType = JSON_DOUBLE; + } + json(std::string v) { + mString = v; + mType = JSON_STRING; + } + json(const char* c, size_t len) { + mString.assign(c, len); + mType = JSON_STRING; + } + json(const json& right) { + mType = right.mType; + mInt = right.mInt; + mBool = right.mBool; + mDouble = right.mDouble; + mString = right.mString; + } + bool operator==(const json& right) const { + if (mType != right.mType) { + return false; + } + switch (mType) { + case JSON_STRING: + return mString == right.mString; + case JSON_INT64: + return mInt == right.mInt; + case JSON_DOUBLE: + return mDouble == right.mDouble; + case JSON_BOOL: + return mBool == right.mBool; + default: + break; + } + return true; + } + bool is_string() const { + return mType == JSON_STRING; + } + bool is_null() const { return mType == JSON_NULL; } + bool is_boolean() const { return mType == JSON_BOOL; } + bool is_number_integer() const { return mType == JSON_INT64; } + bool is_number_float() const { return mType == JSON_DOUBLE; } + bool is_number() const { return mType == JSON_DOUBLE || mType == JSON_INT64; } + bool empty() const { + return mString.empty(); + } + bool get(bool& ) const { + return mBool; + } + std::string get(std::string& ) const { + return mString; + } + int64_t get(int64_t& ) const { + return mInt; + } + int get(int& ) const { + return mInt; + } + float get(float&) const { + return mDouble; + } + double get(double& ) const { + return mDouble; + } + + std::string dump() const { + switch (mType) { + case JSON_STRING: + return mString; + case JSON_INT64: + return std::to_string(mInt); + case JSON_DOUBLE: + return std::to_string(mDouble); + case JSON_BOOL: + return mBool ? "True" : "False"; + default: + break; + } + return "null"; + } +}; + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +inline std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + using CallableType = std::function &, ArgumentsValue &)>; + using FilterType = std::function &, ArgumentsValue &)>; + +private: + using ObjectType = std::map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const std::string& s, std::ostringstream & out, char string_quote = '\'') { + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + auto string_quote = to_json ? '"' : '\''; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + dump_string(it->first, out, string_quote); + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + _printlog("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_.mString, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + Value(json& json) : primitive_(json) { + // Do nothing + } + + Value(const rapidjson::Value & v) { + if (v.IsObject()) { + auto object = std::make_shared(); + for (auto& it : v.GetObject()) { + (*object)[it.name.GetString()] = it.value; + } + object_ = std::move(object); + } else if (v.IsArray()) { + auto array = std::make_shared(); + for (const auto& item : v.GetArray()) { + array->push_back(Value(item)); + } + array_ = array; + } else { + if (v.IsFloat() || v.IsDouble()) { + primitive_ = json(v.GetDouble()); + } else if (v.IsInt() || v.IsInt64()) { + primitive_ = json(v.GetInt64()); + } else if (v.IsBool()) { + primitive_ = json(v.GetBool()); + } else if (v.IsString()) { + primitive_ = json(v.GetString(), v.GetStringLength()); + } + } + } + + std::vector keys() { + if (!object_) _printlog("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.mString.size(); + _printlog("Value is not an array or object: " + dump()); + return 0; + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + _printlog("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + _printlog("Value is not an array: " + dump()); + array_->push_back(v); + } + Value pop(const Value& index) { + if (is_array()) { + if (array_->empty()) + _printlog("pop from empty list"); + if (index.is_null()) { + auto ret = array_->back(); + array_->pop_back(); + return ret; + } else if (!index.is_number_integer()) { + _printlog("pop index must be an integer: " + index.dump()); + } else { + auto i = index.get(); + if (i < 0 || i >= static_cast(array_->size())) + _printlog("pop index out of range: " + index.dump()); + auto it = array_->begin() + (i < 0 ? array_->size() + i : i); + auto ret = *it; + array_->erase(it); + return ret; + } + } else if (is_object()) { + if (!index.is_hashable()) + _printlog("Unhashable type: " + index.dump()); + auto it = object_->find(index.primitive_.dump()); + if (it == object_->end()) + _printlog("Key not found: " + index.dump()); + auto ret = it->second; + object_->erase(it); + return ret; + } else { + _printlog("Value is not an array or object: " + dump()); + } + return Value(); + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) _printlog("Unhashable type: " + dump()); + auto it = object_->find(key.primitive_.dump()); + if (it == object_->end()) return Value(); + return it->second; + } + return Value(); + } + void set(const std::string& key, const Value& value) { + if (!object_) { + _printlog("Value is not an object: " + dump()); + return; + } + (*object_)[key] = value; + } + Value call(const std::shared_ptr & context, ArgumentsValue & args) const { + if (!callable_) { +// _printlog("Value is not callable: " + dump()); + return Value(); + } + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + _printlog("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + void for_each(const std::function & callback) const { + if (is_null()) + _printlog("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if (object_) { + for (auto & item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.mString) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + _printlog("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) return 0; + if (is_boolean()) return get() ? 1 : 0; + if (is_number()) return static_cast(get()); + if (is_string()) { + return std::stol(get()); + } + return 0; + } + + bool operator<(const Value & other) const { + if (is_null()) { + _printlog("Undefined value or reference"); + return false; + } + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + _printlog("Cannot compare values: " + dump() + " < " + other.dump()); + return false; + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) { + _printlog("Undefined value or reference"); + return false; + } + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + _printlog("Cannot compare values: " + dump() + " > " + other.dump()); + return false; + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + _printlog("contains can only be called on arrays and objects: " + dump()); + } + return false; + } + bool contains(const Value & value) const { + if (is_null()) + _printlog("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) _printlog("Unhashable type: " + value.dump()); + return object_->find(value.primitive_.dump()) != object_->end(); + } else { + _printlog("contains can only be called on arrays and objects: " + dump()); + } + return false; + } + void erase(size_t index) { + if (!array_) _printlog("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (!object_) _printlog("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) { + _printlog("Unhashable type: " + dump()); + } + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_.dump()); + _printlog("Value is not an array or object: " + dump()); + return object_->at(index.primitive_.dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) { + _printlog("Undefined value or reference"); + } + if (is_array()) { + return array_->at(index); + } + if (is_object()) { + return object_->at(std::to_string(index)); + } + _printlog("Value is not an array or object: " + dump()); + return array_->at(index); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + T d; + return primitive_.get(d); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) res.push_back(item); + for (const auto& item : *rhs.array_) res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) { + return p.second; + } + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + _printlog(out.str()); + } + } +}; + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + _printlog("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.dump()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) _printlog("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + _printlog("Undefined variable: " + key.dump()); + return values_.at(key); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const std::string & key, const Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + enum Type { + Type_Variable = 0, + Type_If, + Type_Liter, + Type_Array, + Type_Dict, + Type_Slice, + Type_Subscript, + Type_Unary, + Type_Binary, + Type_MethodCall, + Type_Call, + Type_Filter, + }; + using Parameters = std::vector>>; + + Location location; + const int mType; + + Expression(const Location & location, int type) : location(location), mType(type) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + return do_evaluate(context); + } +}; + +class VariableExpr : public Expression { + std::string name; +public: + VariableExpr(const Location & loc, const std::string& n) + : Expression(loc, Expression::Type_Variable), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + context->set(var_names[0], item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + _printlog("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} +enum class LoopControlType { Normal, Break, Continue}; + +class TemplateNode { + Location location_; +protected: + virtual LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + LoopControlType render(std::ostringstream & out, const std::shared_ptr & context) const { + return do_render(out, context); + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & loc, std::vector> && c) + : TemplateNode(loc), children(std::move(c)) {} + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& child : children) { + auto type = child->render(out, context); + if (LoopControlType::Normal != type) { + return type; + } + } + return LoopControlType::Normal; + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & loc, const std::string& t) : TemplateNode(loc), text(t) {} + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + return LoopControlType::Normal; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; +public: + ExpressionNode(const Location & loc, std::shared_ptr && e) : TemplateNode(loc), expr(std::move(e)) {} + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) _printlog("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + return LoopControlType::Normal; + } +}; + +class IfNode : public TemplateNode { + std::vector, std::shared_ptr>> cascade; +public: + IfNode(const Location & loc, std::vector, std::shared_ptr>> && c) + : TemplateNode(loc), cascade(std::move(c)) {} + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) _printlog("IfNode.cascade.second is null"); + return branch.second->render(out, context); + } + } + return LoopControlType::Normal; + } +}; + +class LoopControlNode : public TemplateNode { + LoopControlType control_type_; + public: + LoopControlNode(const Location & loc, LoopControlType control_type) : TemplateNode(loc), control_type_(control_type) {} + LoopControlType do_render(std::ostringstream &, const std::shared_ptr &) const override { + return control_type_; + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; +public: + ForNode(const Location & loc, std::vector && var_names, std::shared_ptr && iterable, + std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) + : TemplateNode(loc), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) _printlog("ForNode.iterable is null"); + if (!body) _printlog("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + _printlog("For loop iterable must be iterable: " + iterable_value.dump()); + } + iterable_value.for_each([&](Value & item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + auto loopcode = else_body->render(out, context); + if (loopcode != LoopControlType::Normal) { + return loopcode; + } + } + } else { + auto loop = recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.empty() || !args.kwargs.empty()) { + _printlog("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + auto control_type = body->render(out, loop_context); + if (control_type == LoopControlType::Break) break; + if (control_type == LoopControlType::Continue) continue; + } + } + return LoopControlType::Normal; + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + _printlog("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + auto code = visit(items); + return Value(); + }; + } + + return visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & loc, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(loc), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + LoopControlType do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + if (!name) _printlog("MacroNode.name is null"); + if (!body) _printlog("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) _printlog("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto& iter : args.kwargs) { + auto& arg_name = iter.first; + auto& value = iter.second; + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) _printlog("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + return LoopControlType::Normal; + } +}; + +class FilterNode : public TemplateNode { + std::shared_ptr filter; + std::shared_ptr body; + +public: + FilterNode(const Location & loc, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(loc), filter(std::move(f)), body(std::move(b)) {} + + LoopControlType do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!filter) _printlog("FilterNode.filter is null"); + if (!body) _printlog("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + _printlog("Filter must be a callable: " + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + return LoopControlType::Normal; + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; +public: + SetNode(const Location & loc, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(loc), ns(ns), var_names(vns), value(std::move(v)) {} + LoopControlType do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!value) _printlog("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + _printlog("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) _printlog("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + return LoopControlType::Normal; + + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; +public: + SetTemplateNode(const Location & loc, const std::string & name, std::shared_ptr && tv) + : TemplateNode(loc), name(name), template_value(std::move(tv)) {} + LoopControlType do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!template_value) _printlog("SetTemplateNode.template_value is null"); + Value value { template_value->render(context) }; + context->set(name, value); + return LoopControlType::Normal; + + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; +public: + IfExpr(const Location & loc, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(loc, Expression::Type_If), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!condition) _printlog("IfExpr.condition is null"); + if (!then_expr) _printlog("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & loc, const Value& v) + : Expression(loc, Expression::Type_Liter), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & loc, std::vector> && e) + : Expression(loc, Expression::Type_Array), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + if (!e) _printlog("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::shared_ptr>> elements; +public: + DictExpr(const Location & loc, std::vector, std::shared_ptr>> && e) + : Expression(loc, Expression::Type_Dict), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& iter : elements) { + const auto& key = iter.first; + const auto& value = iter.second; + if (!key) _printlog("Dict key is null"); + if (!value) _printlog("Dict value is null"); + result.set(key->evaluate(context).to_str(), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::shared_ptr start, end, step; + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e, std::shared_ptr && st = nullptr) + : Expression(loc, Expression::Type_Slice), start(std::move(s)), end(std::move(e)), step(std::move(st)) {} + + Value do_evaluate(const std::shared_ptr &) const override { + _printlog("SliceExpr not implemented"); + return Value(); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; +public: + SubscriptExpr(const Location & loc, std::shared_ptr && b, std::shared_ptr && i) + : Expression(loc, Expression::Type_Subscript), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto target_value = base->evaluate(context); + if (index->mType == Expression::Type_Slice){ + auto slice = (SliceExpr*)(index.get()); + bool reverse = slice->step && slice->step->evaluate(context).get() == -1; + if (slice->step && !reverse) { + MNN_ERROR("Slicing with step other than -1 is not supported"); + } + + int64_t start = slice->start ? slice->start->evaluate(context).get() : (reverse ? target_value.size() - 1 : 0); + int64_t end = slice->end ? slice->end->evaluate(context).get() : (reverse ? -1 : target_value.size()); + + size_t len = target_value.size(); + + if (slice->start && start < 0) { + start = (int64_t)len + start; + } + if (slice->end && end < 0) { + end = (int64_t)len + end; + } + if (target_value.is_string()) { + std::string s = target_value.get(); + + std::string result_str; + if (reverse) { + for (int64_t i = start; i > end; --i) { + if (i >= 0 && i < (int64_t)len) { + result_str += s[i]; + } else if (i < 0) { + break; + } + } + } else { + result_str = s.substr(start, end - start); + } + return result_str; + + } else if (target_value.is_array()) { + auto result = Value::array(); + if (reverse) { + for (int64_t i = start; i > end; --i) { + if (i >= 0 && i < (int64_t)len) { + result.push_back(target_value.at(i)); + } else if (i < 0) { + break; + } + } + } else { + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + } + return result; + } else { + if(target_value.is_null()) { + MNN_ERROR("Cannot subscript null\n"); + } else { + MNN_ERROR("Subscripting only supported on arrays and strings\n"); + } + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (base->mType == Expression::Type_Variable) { + auto t = (VariableExpr*)(base.get()); + _printlog("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); + } + _printlog("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + return Value(); + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location & loc, std::shared_ptr && e, Op o) + : Expression(loc, Expression::Type_Unary), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!expr) _printlog("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + _printlog("Expansion operator is only supported in function calls and collections"); + + } + _printlog("Unknown unary operator"); + return Value(); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::shared_ptr left; + std::shared_ptr right; + Op op; +public: + BinaryOpExpr(const Location & loc, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(loc, Expression::Type_Binary), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!left) _printlog("BinaryOpExpr.left is null"); + if (!right) _printlog("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = (VariableExpr*)(right.get()); + if (right->mType != Expression::Type_Variable) { + _printlog("Right side of 'is' operator must be a variable"); + } + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_iterable(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + if (name == "false") return !l.get(); + if (name == "true") return l.get(); + _printlog("Unknown type for 'is' operator: " + name); + return false; + }; + auto value = eval(); + return Value(op == Op::Is ? value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return l; + return right->evaluate(context); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + _printlog("Unknown binary operator"); + return false; + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr & context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (arg->mType == Expression::Type_Unary) { + auto un_expr = (UnaryOpExpr*)(arg.get()); + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + _printlog("Expansion operator only supported on arrays"); + } + array.for_each([&](Value & value) { + vargs.args.push_back(value); + }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + _printlog("ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value & key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& iter : this->kwargs) { + const auto& name = iter.first; + const auto& value = iter.second; + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { + auto charset = chars.empty() ? " \t\n\r" : chars; + auto start = left ? s.find_first_not_of(charset) : 0; + if (start == std::string::npos) return ""; + auto end = right ? s.find_last_not_of(charset) : s.size() - 1; + return s.substr(start, end - start + 1); +} + +static std::vector split(const std::string & s, const std::string & sep) { + std::vector result; + size_t start = 0; + size_t end = s.find(sep); + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + sep.length(); + end = s.find(sep, start); + } + result.push_back(s.substr(start)); + return result; +} + +static std::string capitalize(const std::string & s) { + if (s.empty()) return s; + auto result = s; + result[0] = std::toupper(result[0]); + return result; +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; +public: + MethodCallExpr(const Location & loc, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(loc, Expression::Type_MethodCall), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) _printlog("MethodCallExpr.object is null"); + if (!method) _printlog("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + // _printlog("Trying to call method '" + method->get_name() + "' on null"); + return Value(); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {0, 1}, {0, 0}); + return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t) obj.size()) _printlog("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {1, 1}, {0, 0}); + return obj.pop(vargs.args[0]); + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + _printlog("Property '" + method->get_name() + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars)); + } else if (method->get_name() == "lstrip") { + vargs.expectArgs("lstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ true, /* right= */ false)); + } else if (method->get_name() == "rstrip") { + vargs.expectArgs("rstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ false, /* right= */ true)); + } else if (method->get_name() == "split") { + vargs.expectArgs("split method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = split(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; + } else if (method->get_name() == "capitalize") { + vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); + return Value(capitalize(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "startswith") { + vargs.expectArgs("startswith method", {1, 1}, {0, 0}); + auto prefix = vargs.args[0].get(); + return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); + else res[i] = std::tolower(res[i]); + } + return res; + } + } + // _printlog("Unknown method: " + method->get_name()); + return Value(); + } +}; + +class CallExpr : public Expression { +public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location & loc, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(loc, Expression::Type_Call), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) { + _printlog("CallExpr.object is null"); + return Value(); + } + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + //_printlog("Object is not callable: " + obj.dump(2)); + return Value(); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & loc, std::vector> && p) + : Expression(loc, Expression::Type_Filter), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) _printlog("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (part->mType == Expression::Type_Call) { + auto ce = (CallExpr*)(part.get()); + auto target = ce->object->evaluate(context); + ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + _printlog("Too many positional params for " + fn_name); + } + } + for (auto & iter : args.kwargs) { + auto& name = iter.first; + auto& value = iter.second; + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + _printlog("Unknown argument " + name + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + AUTOTIME; + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} +std::shared_ptr parse(const std::string& template_str, const Options & options); +} // namespace minja diff --git a/transformers/llm/engine/src/omni.cpp b/transformers/llm/engine/src/omni.cpp index eb51836e..80103cd5 100644 --- a/transformers/llm/engine/src/omni.cpp +++ b/transformers/llm/engine/src/omni.cpp @@ -649,10 +649,13 @@ VARP Omni::gen_position_ids(int seq_len) { positionIds = _Input({3, seq_len}, NCHW, halide_type_of()); } auto ptr = positionIds->writeMap(); - if (seq_len == 1) { - ptr[0] = mContext->gen_seq_len + mPositionIds.back(); - ptr[1] = ptr[0]; - ptr[2] = ptr[0]; + if (mContext->gen_seq_len > 0) { + for (int i=0; igen_seq_len + mPositionIds.back() + i; + ptr[i + 0] = pos; + ptr[i + seq_len] = pos; + ptr[i + seq_len * 2] = pos; + } } else { for (int i = 0; i < seq_len; i++) { ptr[i] = mPositionIds.mT[i]; @@ -666,23 +669,12 @@ VARP Omni::gen_position_ids(int seq_len) { return positionIds; } -Express::VARP Omni::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) { - VARP logits; - auto logitsIndex = _var({-1}, {1}); - if (mConfig->all_logits()) { - logitsIndex = _var({0}, {1}); - } - std::vector outputs; - outputs = mCurrentModules.back()->onForward({hiddenState, mask, inputPos, logitsIndex}); - if (outputs.empty()) { - return nullptr; - } +std::vector Omni::forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) { + auto outputs = Llm::forwardRaw(hiddenState, mask, inputPos); if (mTalker && outputs.size() > 1) { mTalker->addTalkerEmbeds(outputs[1]); } - logits = outputs[0]; - mMeta->sync(); - return logits; + return outputs; } void Omni::response(const std::vector& input_ids, std::ostream* os, const char* end_with, int max_new_tokens) { @@ -733,6 +725,7 @@ void Omni::generateWavform() { void Talker::load() { initRuntime(); + mSeqLenIndex = 1; set_config("{\"sampler_type\": \"mixed\", \"temperature\": 0.9, \"topK\": 40, \"topP\": 0.8, \"penalty\": 1.05}"); mSampler.reset(Sampler::createSampler(mContext, mConfig)); mDiskEmbedding.reset(new DiskEmbedding(mConfig, mConfig->talker_embedding_file())); @@ -753,7 +746,9 @@ void Talker::load() { module_config.shapeMutable = false; module_config.rearrange = true; mModules.resize(1); - mModules[0].reset(Module::load({"inputs_embeds", "attention_mask", "position_ids"}, + std::vector inputNames {"inputs_embeds", "attention_mask", "position_ids", "logits_index"}; + + mModules[0].reset(Module::load(inputNames, {"logits"}, mConfig->talker_model().c_str(), mRuntimeManager, &module_config)); // dit mPreDit.reset(Module::load({"cond", "spk", "code"}, {"code_embeds", "rope", "mask"}, @@ -770,24 +765,7 @@ void Talker::load() { void Talker::generate_init(std::ostream* os, const char* end_with) { if (!doGenerate()) { return; } - { - mContext->os = os; - if (nullptr != end_with) { - mContext->end_with = end_with; - } - if (!mContext->generate_str.empty()) { - mContext->generate_str.clear(); - } - mContext->gen_seq_len = 0; - mContext->prefill_us = 0; - mContext->decode_us = 0; - mContext->current_token = 0; - mContext->all_seq_len = 0; - mContext->history_tokens.clear(); - mMeta->remove = mMeta->previous; - mContext->output_tokens.clear(); - mCurrentModules = mPrefillModules; - } + Llm::generate_init(os, end_with); // stream generate init mTalkerEmbeds.clear(); if (mInitialNoise.empty()) { @@ -947,19 +925,6 @@ int Talker::sample(Express::VARP logits, int offset, int size) { return token; } -VARP Talker::forward(VARP input_embeds) { - auto input_shape = input_embeds->getInfo()->dim; - int seq_len = input_shape[1]; - mMeta->add = seq_len; - auto attention_mask = gen_attention_mask(seq_len); - auto position_ids = gen_position_ids(seq_len); - auto outputs = mCurrentModules.back()->onForward({input_embeds, attention_mask, position_ids}); - mContext->all_seq_len += seq_len; - mContext->gen_seq_len++; - mMeta->sync(); - return outputs[0]; -} - void Talker::generate() { if (!doGenerate()) { return; } mTalkerEmbeds.push_back(mTextEos); @@ -970,11 +935,13 @@ void Talker::generate() { mContext->prompt_len = input_embeds->getInfo()->dim[1]; MNN::Timer _t; auto logits = forward(input_embeds); - int token = sample(logits); + mContext->current_token = sample(logits); + mContext->history_tokens.push_back(mContext->current_token); + mContext->output_tokens.push_back(mContext->current_token); mContext->prefill_us += _t.durationInUs(); _t.reset(); for (int i = 1; i < mMaxNewTokens; i++) { - input_embeds = embedding({token}); + input_embeds = embedding({mContext->current_token}); if (i + 1 < mTalkerEmbeds.size()) { input_embeds = input_embeds + mTalkerEmbeds[i + 1]; } else { @@ -982,8 +949,11 @@ void Talker::generate() { input_embeds = input_embeds + mTextPad; } auto logits = forward(input_embeds); - token = sample(logits); - if (token == 8292 || token == 8294) { + mContext->current_token = sample(logits); + mContext->history_tokens.push_back(mContext->current_token); + mContext->output_tokens.push_back(mContext->current_token); + + if (mContext->current_token == 8292 || mContext->current_token == 8294) { break; } } diff --git a/transformers/llm/engine/src/omni.hpp b/transformers/llm/engine/src/omni.hpp index ef0f916e..2b5eaeaf 100644 --- a/transformers/llm/engine/src/omni.hpp +++ b/transformers/llm/engine/src/omni.hpp @@ -69,7 +69,6 @@ public: VARP bigvganForward(VARP mel); VARP token2wav(const std::vector& codec_tokens); void token2wav(bool talker_done = false); - VARP forward(VARP input_embeds); void generate(); void setPostionIds(const MropeInfo& positionIds); void addTalkerEmbeds(VARP talker_embeds); @@ -104,7 +103,7 @@ public: mAudioModule.reset(); } virtual void load() override; - virtual Express::VARP forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override; + virtual std::vector forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override; virtual std::vector tokenizer_encode(const std::string& query) override; virtual Express::VARP embedding(const std::vector& input_ids) override; virtual Express::VARP gen_position_ids(int seq_len) override; diff --git a/transformers/llm/engine/src/prompt.cpp b/transformers/llm/engine/src/prompt.cpp index d842d98b..eb80c7e9 100644 --- a/transformers/llm/engine/src/prompt.cpp +++ b/transformers/llm/engine/src/prompt.cpp @@ -1,8 +1,38 @@ #include "prompt.hpp" - +#ifdef LLM_USE_MINJA +#include "minja/chat_template.hpp" +#endif namespace MNN { namespace Transformer { - +#ifdef LLM_USE_MINJA +class Prompt::JinjaTemplate { +public: + JinjaTemplate(const std::string& chatTemplate, const std::string& bos, const std::string& eos) : mTemplate(chatTemplate, bos, eos) { + // Do nothing + } + ~ JinjaTemplate() { + // Do nothing + } + void setExtraContext(const rapidjson::Value& extra) { + mInput.extra_context.CopyFrom(extra, mInput.extra_context.GetAllocator()); + } + std::string apply(const std::vector& inputs, bool addGeneration) { + mInput.messages.SetArray(); + for (auto& message : inputs) { + rapidjson::Value value; + value.SetObject(); + value.AddMember("role", rapidjson::StringRef(message.first.c_str()), mInput.messages.GetAllocator()); + value.AddMember("content", rapidjson::StringRef(message.second.c_str()), mInput.messages.GetAllocator()); + mInput.messages.PushBack(value, mInput.messages.GetAllocator()); + } + mInput.add_generation_prompt = addGeneration; + return mTemplate.apply(mInput); + } +private: + minja::chat_template mTemplate; + minja::chat_template_inputs mInput; +}; +#endif static std::string buildPrompt(ChatMessage item, std::string prompt_template, std::string placeholder) { size_t start_pos = prompt_template.find(placeholder); if (start_pos == std::string::npos) { @@ -22,7 +52,28 @@ bool contains(const std::string& str, const std::string& substring) { } void Prompt::setParams(std::shared_ptr config) { - mReuseKV = config->reuse_kv(); +#ifdef LLM_USE_MINJA + if (config->config_.document.HasMember("jinja")) { + auto& document = config->config_.document["jinja"]; + if (nullptr == mCommonTemplate.get()) { + // Only create jinja once + std::string bosToken, eosToken; + if (document.HasMember("bos") && document["bos"].IsString()) { + bosToken = document["bos"].GetString(); + } + if (document.HasMember("eos") && document["eos"].IsString()) { + eosToken = document["eos"].GetString(); + } + std::string templateChat = document["chat_template"].GetString(); + mCommonTemplate.reset(new JinjaTemplate(templateChat, bosToken, eosToken)); + } + if (document.HasMember("context")) { + mCommonTemplate->setExtraContext(document["context"]); + } + return; + } +#endif + mCommonTemplate.reset(); mSystemPrompt = config->system_prompt(); if (config->config_.document.HasMember("prompt_template")) { // std::cout << "legacy prompt_template" << std::endl; @@ -125,7 +176,12 @@ std::string Prompt::applyTemplate(std::string user_content, bool add_system_prom return applyTemplate(prompts, add_generation_prompt); } -std::string Prompt::applyTemplate(std::vector inputs, bool add_generation_prompt) { +std::string Prompt::applyTemplate(const std::vector& inputs, bool add_generation_prompt) { +#ifdef LLM_USE_MINJA + if (nullptr != mCommonTemplate.get()) { + return mCommonTemplate->apply(inputs, add_generation_prompt); + } +#endif std::string prompt_str = mBos; for (auto input : inputs) { if (input.first == "") continue; diff --git a/transformers/llm/engine/src/prompt.hpp b/transformers/llm/engine/src/prompt.hpp index 19447c61..c086cf15 100644 --- a/transformers/llm/engine/src/prompt.hpp +++ b/transformers/llm/engine/src/prompt.hpp @@ -13,8 +13,9 @@ namespace MNN { namespace Transformer { -class MNN_PUBLIC Prompt { +class Prompt { private: + class JinjaTemplate; std::shared_ptr mContext; std::string mPromptTemplate; // for compatibility std::string mSystemPrompt; @@ -23,14 +24,14 @@ private: std::string mSystemName = "system", mUserName = "user", mAssistantName = "assistant"; - bool mReuseKV = false; + std::shared_ptr mCommonTemplate; public: static Prompt* createPrompt(std::shared_ptr context, std::shared_ptr config); Prompt(std::shared_ptr context, std::shared_ptr config); std::string getAssistantSuffix() const; void setParams(std::shared_ptr config); std::string applyTemplate(std::string user_content, bool add_system_prompt = false, bool add_generation_prompt = true); - std::string applyTemplate(std::vector inputs, bool add_generation_prompt = true); + std::string applyTemplate(const std::vector& inputs, bool add_generation_prompt = true); }; } diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py index f37b5c5a..eaed33eb 100644 --- a/transformers/llm/export/llmexport.py +++ b/transformers/llm/export/llmexport.py @@ -101,11 +101,9 @@ class LlmExporter(torch.nn.Module): elif 'SmolVLM2' in model_path: from transformers import AutoModelForImageTextToText self.model = AutoModelForImageTextToText.from_pretrained(model_path, torch_dtype='auto').eval() - elif 'SmolVLM' in model_path: + elif 'SmolVLM' in model_path or 'SmolDocling' in model_path: from transformers import AutoModelForVision2Seq self.model = AutoModelForVision2Seq.from_pretrained(model_path, torch_dtype='auto').eval() - elif 'FastVLM' in model_path: - self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval() else: try: self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval() @@ -203,12 +201,14 @@ class LlmExporter(torch.nn.Module): "system_prompt_template": prompt_template['system'].format(content='%s'), 'user_prompt_template': prompt_template['user'].format(content='%s'), 'assistant_prompt_template': prompt_template['assistant'].format(content='%s'), + 'jinja':prompt_template['jinja'], 'is_visual': False } # load modules ModelMapper.do_map(self, self.model, self.model_map['model']) - self.tie_word_embeddings = self.args.tie_embed and self.lm_.weight.equal(self.embed_.weight) + self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight) if self.tie_word_embeddings: + print("Tie word embeddings in lm, set lm quant bit to 8") self.args.lm_quant_bit = 8 # rebuild modules if self.lm_ is None: @@ -349,6 +349,12 @@ class LlmExporter(torch.nn.Module): 'user': '{content}', 'assistant': '{content}' } + template['jinja'] = {} + template['jinja']['chat_template'] = self.tokenizer.get_chat_template() + if None != self.tokenizer.bos_token: + template['jinja']['bos'] = self.tokenizer.bos_token + if None != self.tokenizer.eos_token: + template['jinja']['eos'] = self.tokenizer.eos_token if self.model_type == 'baichuan': template['user'] = '{content}' template['assistant'] = '{content}' @@ -505,8 +511,8 @@ class LlmExporter(torch.nn.Module): attention_mask, position_ids, past_key_values, - cross_attention_states, - cross_attention_mask) + cross_attention_states = cross_attention_states, + cross_attention_mask = cross_attention_mask) token_id = torch.argmax(logits[:,-1,:]) if token_id in self.stop_ids: print("", end='\n') @@ -548,6 +554,8 @@ class LlmExporter(torch.nn.Module): "precision": "low", "memory": "low", # "system_prompt": "You are a helpful assistant.", + "sampler_type":'penalty', + "penalty":1.1 } if self.talker is not None: config['system_prompt'] = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." @@ -1060,13 +1068,13 @@ def export(path, export = 'onnx', onnx_slim = False, quant_bit = 4, - quant_block = 128, + quant_block = 64, lm_quant_bit = None, mnnconvert = None, ppl = False, awq = False, sym = False, - tie_embed = False, + seperate_embed = False, lora_split = False): args = argparse.Namespace() for k, v in { @@ -1085,7 +1093,7 @@ def export(path, 'ppl': ppl, 'awq': awq, 'sym': sym, - 'tie_embed': tie_embed, + 'seperate_embed': seperate_embed, 'lora_split': lora_split }.items(): setattr(args, k, v) @@ -1115,13 +1123,13 @@ def main(): parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.') parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.') parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.') - parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, 0 mean channle-wise, default is 128.') + parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 mean channle-wise, default is 64.') parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.') parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.') parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.') parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.') parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.') - parser.add_argument('--tie_embed', action='store_true', help='Whether or not to using tie_embedding, defualt is False, if True, lm_quant_bit will be 8.') + parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.') parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.') args = parser.parse_args() diff --git a/transformers/llm/export/utils/vision.py b/transformers/llm/export/utils/vision.py index bc91d0a3..88401217 100644 --- a/transformers/llm/export/utils/vision.py +++ b/transformers/llm/export/utils/vision.py @@ -942,9 +942,8 @@ class Idefics3Vision(Vision): image = convert_to_rgb(image) image = to_numpy_array(image) resized_height, resized_width = self.get_size(self.image_height, self.image_width) - resized_height, resized_width = 1, 1 format = infer_channel_dimension_format(image) - resample = PILImageResampling.BICUBIC + resample = PILImageResampling.LANCZOS global_image = resize(image, size=(self.patch_size, self.patch_size), resample=resample, input_data_format=format) def preprocess(image): image = rescale(image, scale=1 / 255.0, input_data_format=format) diff --git a/transformers/llm/resource/config/metal_lookahead.json b/transformers/llm/resource/config/metal_lookahead.json new file mode 100644 index 00000000..d72789cb --- /dev/null +++ b/transformers/llm/resource/config/metal_lookahead.json @@ -0,0 +1,15 @@ +{ + "llm_model": "llm.mnn", + "llm_weight": "llm.mnn.weight", + "backend_type": "metal", + "thread_num": 4, + "precision": "low", + "memory": "low", + "system_prompt": "You are a helpful assistant.", + "speculative_type": "lookahead", + "draft_predict_length": 8, + "draft_match_strictness": "low", + "ngram_match_maxlen": 8, + "draft_selection_rule": "fcfs", + "ngram_update": false +} diff --git a/transformers/llm/resource/config/mmap_config.json b/transformers/llm/resource/config/mmap_config.json new file mode 100644 index 00000000..82ee8fae --- /dev/null +++ b/transformers/llm/resource/config/mmap_config.json @@ -0,0 +1,10 @@ +{ + "backend_type": "cpu", + "thread_num": 4, + "precision": "low", + "memory": "low", + "power":"high", + "sampler_type":"penalty", + "penalty":1.1, + "use_mmap":true +} diff --git a/transformers/llm/resource/config/omni_lookahead_config.json b/transformers/llm/resource/config/omni_lookahead_config.json new file mode 100644 index 00000000..6b5f12ac --- /dev/null +++ b/transformers/llm/resource/config/omni_lookahead_config.json @@ -0,0 +1,26 @@ +{ + "llm_model": "llm.mnn", + "llm_weight": "llm.mnn.weight", + "backend_type": "cpu", + "thread_num": 4, + "precision": "low", + "memory": "low", + "system_prompt": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.", + "talker_max_new_tokens": 2048, + "talker_speaker": "Chelsie", + "dit_steps": 5, + "dit_solver": 1, + "mllm": { + "backend_type": "cpu", + "thread_num": 4, + "precision": "normal", + "memory": "low" + }, + "reuse_kv":false, + "speculative_type": "lookahead", + "draft_predict_length": 8, + "draft_match_strictness": "low", + "ngram_match_maxlen": 8, + "draft_selection_rule": "fcfs", + "ngram_update": false +} diff --git a/transformers/llm/resource/prompt/demo.txt b/transformers/llm/resource/prompt/demo.txt new file mode 100644 index 00000000..048fb691 --- /dev/null +++ b/transformers/llm/resource/prompt/demo.txt @@ -0,0 +1,4 @@ +计算8乘以12 +将下面的句子翻译成中文:It's a beautiful day to learn something new. +描述优秀的领导者应具备的五个特质,并解释每个特质为什么重要 +近年来,随着技术的快速发展和全球化的深入推进,数字经济已成为推动世界经济增长的新引擎。数字经济不仅改变了人们的生活方式,促进了信息和资源的快速流通,还重塑了传统行业的业务模式和竞争格局。尽管数字经济的发展为全球经济增长提供了新的动能,但同时也带来了数据安全、隐私保护、数字鸿沟和市场垄断等一系列挑战。考虑到这些背景,请详细分析数字经济在促进世界经济增长方面的作用,包括但不限于数字经济对提高生产效率、创造就业机会和促进可持续发展的贡献。同时,探讨如何应对数字经济发展过程中出现的挑战,具体包括如何保护个人数据安全和隐私、缩小数字鸿沟以确保数字经济的包容性和公平性,以及如何制定有效政策以避免市场垄断情况的出现,最终实现数字经济的健康和可持续发展。 diff --git a/transformers/llm/resource/prompt/prompt_code.txt b/transformers/llm/resource/prompt/prompt_code.txt new file mode 100644 index 00000000..24a8ec23 --- /dev/null +++ b/transformers/llm/resource/prompt/prompt_code.txt @@ -0,0 +1,29 @@ +user +```cpp +int partition(vector& arr, int low, int high) { + long pivot = arr[high]; + int i = low - 1; + + for (int j = low; j < high; ++j) { + if (arr[j] < pivot) { + ++i; + swap(arr[i], arr[j]); + } + } + + swap(arr[i + 1], arr[high]); + return i + 1; +} + +void quickSort(vector& arr, int low, int high) { + if (low < high) { + int pi = partition(arr, low, high); + + quickSort(arr, low, pi - 1); + + quickSort(arr, pi + 1, high); + } +} +``` + +将上述快速排序算法中,arr数据类型由long改为int,重新输出代码。要求:缩进空格数与prompt保持一致。只输出代码,不用原理解释。 \ No newline at end of file