MNN/source/backend/opencl/execution/cl/splitgelu_buf_mnn_cl.cpp

71 lines
2.4 KiB
C++

#include "opencl_source_map.hpp"
namespace MNN {
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C source of the `splitgelu_buf` kernel, embedded as a C string.
// Only compiled in when MNN_OPENCL_BUFFER_CLOSED is not defined.
//
// Kernel contract (as written in the source below):
//   global_dim0 / global_dim1 : bounds for get_global_id(0) / get_global_id(1);
//                               work items outside them do nothing.
//   input   : for each index bc (global id 1) a run of 2*shape.z elements; the
//             first shape.z are the "left" values, the next shape.z the "right".
//   input1  : only with DOUBLE_INPUTS; 2*shape.z elements added to the left and
//             right values (same values for every bc).
//   output  : shape.z elements per bc.
//   shape   : only shape.z is read.
//
// Per element the kernel computes
//   out = left * 0.5 * right * (1 + erf(right * 0.70710679...))
// i.e. the erf-based GELU of the right half multiplied by the left half.
//
// Build-time switches: WH_16 / WH_4 make each work item process 16 / 4
// consecutive elements with vector loads and stores; otherwise one element
// per work item. FLOAT, CONVERT_FLOAT4 and CONVERT_FLOAT16 are expected to be
// supplied by the build options (not defined in this file).
//
// All floating-point literals carry the `f` suffix on purpose: an unsuffixed
// literal is a double in OpenCL C, which would promote the scalar path's
// arithmetic (including the erf call) to double precision and make the result
// depend on the device's fp64 support.
//
// NOTE(review): this file looks generated from a .cl source; if so, mirror the
// literal change there as well — TODO confirm.
const char* splitgelu_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"__kernel void splitgelu_buf(__private int global_dim0,__private int global_dim1,\n"
" __global const FLOAT*input,\n"
" #ifdef DOUBLE_INPUTS\n"
" __global const FLOAT*input1,\n"
" #endif\n"
" __global FLOAT*output,\n"
" __private const int4 shape\n"
"){\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int h=pos.x;\n"
" const int bc=pos.y;\n"
"// The product of W and H is a multiple of 16\n"
"#ifdef WH_16\n"
" const int in_offset=bc*shape.z*2+h*16;\n"
" const int out_offset=bc*shape.z+h*16;\n"
" float16 valueL=convert_float16(vload16(0,input+in_offset));\n"
" float16 valueR=convert_float16(vload16(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float16 valueConstL=convert_float16(vload16(h,input1));\n"
" float16 valueConstR=convert_float16(vload16(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float16 out=(erf(valueR*(float16)0.7071067932881648f)+(float16)1.0f)*valueR*(float16)0.5f;\n"
" out *= valueL;\n"
" vstore16(CONVERT_FLOAT16(out),0,output+out_offset);\n"
"// The product of W and H is a multiple of 4\n"
"#elif defined (WH_4)\n"
" const int in_offset=bc*shape.z*2+h*4;\n"
" const int out_offset=bc*shape.z+h*4;\n"
" float4 valueL=convert_float4(vload4(0,input+in_offset));\n"
" float4 valueR=convert_float4(vload4(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float4 valueConstL=convert_float4(vload4(h,input1));\n"
" float4 valueConstR=convert_float4(vload4(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float4 out=(erf(valueR*(float4)0.7071067932881648f)+(float4)1.0f)*valueR*(float4)0.5f;\n"
" out *= valueL;\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
"#else\n"
" const int in_offset=bc*shape.z*2+h;\n"
" const int out_offset=bc*shape.z+h;\n"
" \n"
" float valueL=(float)input[in_offset];\n"
" float valueR=(float)input[in_offset+shape.z];\n"
" #ifdef DOUBLE_INPUTS\n"
" float valueConstL=input1[h];\n"
" float valueConstR=input1[shape.z+h];\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float out=(erf(valueR*0.7071067932881648f)+1.0f)*valueR*0.5f;\n"
" out *= valueL;\n"
" output[out_offset]=out;\n"
"#endif\n"
" }\n"
"}\n"
;
#endif
}