// Mirror of https://github.com/alibaba/MNN.git
// Upstream file listing: 71 lines, 2.4 KiB, C++.
#include "opencl_source_map.hpp"
// Embedded OpenCL kernel source, exposed as a C string in namespace MNN.
namespace MNN {

#ifndef MNN_OPENCL_BUFFER_CLOSED
/*
 * OpenCL C source text of the `splitgelu_buf` kernel, kept as one
 * NUL-terminated string built from adjacent string literals (one literal per
 * kernel source line). The definition is compiled out when
 * MNN_OPENCL_BUFFER_CLOSED is defined.
 *
 * Kernel arguments, as used by the kernel text below:
 *   global_dim0, global_dim1  upper bounds for get_global_id(0) / (1); work
 *                             items outside these bounds do nothing.
 *   input                     source buffer. For work item (h, bc) the "left"
 *                             value is read at bc*shape.z*2 + h*V and the
 *                             "right" value shape.z elements after it.
 *   input1                    present only when DOUBLE_INPUTS is defined; the
 *                             elements at h*V and shape.z + h*V are added to
 *                             the left / right values before the activation.
 *   output                    destination buffer, written at bc*shape.z + h*V.
 *   shape                     int4 of which only the .z component is read.
 * V is 16 when WH_16 is defined, 4 when WH_4 is defined, otherwise 1.
 *
 * Per element the kernel computes
 *   out = left * 0.5 * right * (1 + erf(right * 0.7071067932881648))
 * in `float`, then stores it to `output`.
 *
 * FLOAT, CONVERT_FLOAT16 and CONVERT_FLOAT4 are not defined in this string;
 * presumably they are supplied through the program build options - verify
 * against the code that builds this kernel.
 *
 * All floating-point constants carry an `f` suffix so that the three branches
 * evaluate in single precision. An unsuffixed constant is double-typed in
 * OpenCL C when the device supports double, which would promote the scalar
 * branch to double arithmetic, unlike the vector branches.
 */
const char* splitgelu_buf =
// Optional half-precision support.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// Kernel signature; `input1` exists only in DOUBLE_INPUTS builds.
"__kernel void splitgelu_buf(__private int global_dim0,__private int global_dim1,\n"
" __global const FLOAT*input,\n"
" #ifdef DOUBLE_INPUTS\n"
" __global const FLOAT*input1,\n"
" #endif\n"
" __global FLOAT*output,\n"
" __private const int4 shape\n"
"){\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
// Guard against work items outside the requested range.
" if (pos.x<global_dim0 && pos.y<global_dim1) {\n"
" const int h=pos.x;\n"
" const int bc=pos.y;\n"
// Branch 1: 16 elements per work item.
"// The product of W and H is a multiple of 16\n"
"#ifdef WH_16\n"
" const int in_offset=bc*shape.z*2+h*16;\n"
" const int out_offset=bc*shape.z+h*16;\n"
" float16 valueL=convert_float16(vload16(0,input+in_offset));\n"
" float16 valueR=convert_float16(vload16(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float16 valueConstL=convert_float16(vload16(h,input1));\n"
" float16 valueConstR=convert_float16(vload16(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float16 out=(erf(valueR*(float16)0.7071067932881648f)+(float16)1.0f)*valueR*(float16)0.5f;\n"
" out *= valueL;\n"
" vstore16(CONVERT_FLOAT16(out),0,output+out_offset);\n"
// Branch 2: 4 elements per work item.
"// The product of W and H is a multiple of 4\n"
"#elif defined (WH_4)\n"
" const int in_offset=bc*shape.z*2+h*4;\n"
" const int out_offset=bc*shape.z+h*4;\n"
" float4 valueL=convert_float4(vload4(0,input+in_offset));\n"
" float4 valueR=convert_float4(vload4(0,input+in_offset+shape.z));\n"
" #ifdef DOUBLE_INPUTS\n"
" float4 valueConstL=convert_float4(vload4(h,input1));\n"
" float4 valueConstR=convert_float4(vload4(h,input1+shape.z));\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
" float4 out=(erf(valueR*(float4)0.7071067932881648f)+(float4)1.0f)*valueR*(float4)0.5f;\n"
" out *= valueL;\n"
" vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n"
// Branch 3: scalar fallback, one element per work item.
"#else\n"
" const int in_offset=bc*shape.z*2+h;\n"
" const int out_offset=bc*shape.z+h;\n"
" \n"
" float valueL=(float)input[in_offset];\n"
" float valueR=(float)input[in_offset+shape.z];\n"
" #ifdef DOUBLE_INPUTS\n"
" float valueConstL=input1[h];\n"
" float valueConstR=input1[shape.z+h];\n"
" valueL += valueConstL;\n"
" valueR += valueConstR;\n"
" #endif\n"
// Single-precision constants keep this branch consistent with the vector ones.
" float out=(erf(valueR*0.7071067932881648f)+1.0f)*valueR*0.5f;\n"
" out *= valueL;\n"
" output[out_offset]=out;\n"
"#endif\n"
" }\n"
"}\n"
;
#endif

} // namespace MNN
|