// Source path: MNN/source/backend/opencl/execution/cl/scale_mnn_cl.cpp
// Web-viewer listing: 33 lines, 1.4 KiB, C++; last change 2025-04-28 11:38:44 +08:00.
#include "opencl_source_map.hpp"
namespace MNN {

// OpenCL C source text of the `scale` kernel, kept as a single NUL-terminated
// string assembled from adjacent string literals (the C++ compiler joins them).
//
// For each work-item the kernel reads one texel of `input`, multiplies it by
// the `scale` texel at column `channel_block_idx`, adds the matching `bias`
// texel when HAS_BIAS is defined, and writes the result to `output`.
//
// NOTE(review): FLOAT4, RI_F and WI_F are not defined inside this text;
// presumably they are injected through the program build options -- confirm
// against the code that compiles this source.
const char* scale =
    // --- optional half-precision extension --------------------------------
    "#ifdef MNN_SUPPORT_FP16\n"
    "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
    "#endif\n"

    // --- helper macros ------------------------------------------------------
    // GLOBAL_SIZE_3_DIMS: the three leading kernel parameters carrying the
    // logical global size.
    "#define GLOBAL_SIZE_3_DIMS "
    " __private const int global_size_dim0,"
    "__private const int global_size_dim1,"
    "__private const int global_size_dim2,\n"
    // DEAL_NON_UNIFORM_DIM3: leave the kernel early when any of the three ids
    // is at or beyond the corresponding logical global size.
    "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "
    " if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "
    " return; "
    " }\n"

    // --- sampler: unnormalized coordinates, clamp addressing, nearest filter
    "__constant sampler_t SAMPLER="
    "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"

    // --- kernel signature (the bias image exists only under HAS_BIAS) -------
    "__kernel void scale(GLOBAL_SIZE_3_DIMS "
    "__read_only image2d_t input,"
    "__read_only image2d_t scale,\n"
    "#ifdef HAS_BIAS\n"
    " __read_only image2d_t bias,/* cout%4*cout/4 */\n"
    "#endif\n"
    " __write_only image2d_t output) {\n"

    // --- work-item ids and bounds guard -------------------------------------
    " const int channel_block_idx=get_global_id(0);\n"
    " const int w=get_global_id(1);\n"
    " const int hb=get_global_id(2);\n"
    " DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n"

    // --- image x coordinate: channel_block_idx * width + w ------------------
    " const int width=global_size_dim1;\n"
    " const int pos=mad24(channel_block_idx,width,w);\n"

    // --- load the input texel and the per-channel-block scale ---------------
    " FLOAT4 in=RI_F(input,SAMPLER,(int2)(pos,hb));\n"
    " FLOAT4 scale_value=RI_F(scale,SAMPLER,(int2)(channel_block_idx,0));\n"

    // --- out = in * scale (+ bias when HAS_BIAS is defined) -----------------
    "#ifdef HAS_BIAS\n"
    " FLOAT4 bias_value=RI_F(bias,SAMPLER,(int2)(channel_block_idx,0));\n"
    " FLOAT4 out=in*scale_value+bias_value;\n"
    "#else\n"
    " FLOAT4 out=in*scale_value;\n"
    "#endif\n"

    // --- store the result at the same coordinate it was read from -----------
    " WI_F(output,(int2)(pos,hb),out);\n"
    "}\n";

}  // namespace MNN