// MNN/source/backend/opencl/execution/cl/gemm_mnn_cl.cpp
//
// (Source-viewer metadata: 328 lines, 13 KiB, C++.)

#include "opencl_source_map.hpp"
namespace MNN {
// OpenCL C program source for the GEMM kernels below, stored as a single C
// string assembled from adjacent string literals. Every literal piece ends in
// "\n", so the program text keeps one kernel line per C++ source line. The
// `""` splits inside a few lines mark places where a multi-line macro was
// joined onto one line; adjacent literals simply concatenate, so the splits
// add no characters of their own. Comments placed between the pieces are
// not part of the string.
//
// NOTE(review): this layout looks machine-generated from a .cl kernel file
// (presumably gemm.cl) -- confirm, and prefer editing that file and
// regenerating over hand-editing this string.
//
// Kernels defined by this program: gemm, gemmWinograd, gemmWinogradW2,
// gemm_conv, gemm_conv_b2.
//
// The program text uses FLOAT, FLOAT4, FLOAT16, RI_F and WI_F without
// defining them, so the host must supply them when it builds the program
// (presumably as -D build options -- TODO confirm against the program
// builder). Optional switches tested by the text: MNN_SUPPORT_FP16,
// INPUT_CHANNEL_LEAVE, RELU, RELU6.
//
// `gemm` is a non-const pointer (to const char) at namespace scope, so it has
// external linkage. Presumably opencl_source_map.hpp declares it and
// registers it by name -- TODO confirm. Changing it to `const char* const`
// would give it internal linkage in C++ and break such an extern reference.
const char* gemm =
// Enable half precision only when the host defines MNN_SUPPORT_FP16.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_DIM2 expands to two leading kernel parameters
// (global_size_dim0, global_size_dim1). UNIFORM_BOUNDRY_CHECK returns from any
// work-item whose indices fall outside them. In this program both macros are
// used only by gemm_conv and gemm_conv_b2.
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
// Sampler used by every RI_F read below: unnormalized integer coordinates,
// clamp-to-border addressing, nearest filtering.
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// gemm: a work-item is active when pos.x < width*height and pos.y < alpha2,
// with pos_x = pos.x % width, pos_y = pos.x / width and pos_z = pos.y.
// For each k in [0, multiLength) it reads four uKernel texels at
// x = 4k..4k+3, row kenerlY = pos_z*height + pos_y, and four uInput texels at
// the same x, row srcY = pos_z*width + pos_x. It then accumulates
//   o_i += s_i.x*k0 + s_i.y*k1 + s_i.z*k2 + s_i.w*k3   (i = 0..3).
// o0..o3 are written to uOutput column srcY, rows 4*pos_y .. 4*pos_y+3.
"__kernel void gemm(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int width,__private const int height,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" if (pos.x<width*height && pos.y<alpha2) {\n"
" \n"
" const int pos_x=pos.x % width;\n"
" const int pos_y=pos.x/width;\n"
" const int pos_z=pos.y;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" int kenerlY=mad24(pos_z,height,pos_y);\n"
" int srcY=mad24(pos_z,width,pos_x);\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,kenerlY));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,kenerlY));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,kenerlY));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,kenerlY));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(index,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(index+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(index+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(index+3,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" }\n"
" __private int out_y_idx=mul24(pos_y,4);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+1),o1);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+2),o2);\n"
" WI_F(uOutput,(int2)(srcY,out_y_idx+3),o3);\n"
" }\n"
"}\n"
// gemmWinograd: each work-item covers up to four consecutive columns
// (srcX = 4*pos_x) of one tile row.
//   - pos.x ranges over ceil(unitWidth/4)*unitHeight and is split into
//     pos_x / pos_y.
//   - pos.y ranges over alpha2*dstChannelC4 and is split into
//     pos_z = pos.y % dstChannelC4 and pos_w = pos.y / dstChannelC4.
// For each k, uKernel is read at row pos.y, x = 4k..4k+3, and uInput at row
// pos_w*unitHeight + pos_y, x = k*unitWidth + srcX + 0..3.
// Results go to row pos_z*unitHeight + pos_y, x = pos_w*unitWidth + srcX + i.
// The remain ladder writes only the min(4, unitWidth - srcX) columns that
// exist. For a partial last group, the extra columns are still read (they
// land in the next k's block, or past the image edge) and accumulated, but
// their results are discarded.
"__kernel void gemmWinograd(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,__private const int unitHeight,__private const int dstChannelC4,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" const int unitWidth4=(unitWidth+3)/4;\n"
" if (pos.x<unitWidth4*unitHeight && pos.y<alpha2*dstChannelC4) {\n"
" \n"
" const int pos_x=pos.x % unitWidth4;\n"
" const int pos_y=pos.x/unitWidth4;\n"
" const int pos_z=pos.y % dstChannelC4;\n"
" const int pos_w=pos.y/dstChannelC4;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" int srcY=mad24(pos_w,unitHeight,pos_y);\n"
" int srcX=pos_x << 2;\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" __private int x_offset=mul24(k,unitWidth);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,pos.y));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,pos.y));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,pos.y));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,pos.y));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+3,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" }\n"
" __private int out_y_idx=mad24(pos_z,unitHeight,pos_y);\n"
" __private int out_x_idx=mad24(pos_w,unitWidth,srcX);\n"
" const int remain=unitWidth-srcX;\n"
" if(remain >= 4){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" }else if(remain == 3){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" }else if(remain == 2){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" }else if(remain == 1){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" }\n"
" }\n"
"}\n"
// gemmWinogradW2: same indexing and layout as gemmWinograd, but each
// work-item covers up to eight consecutive columns (srcX = 8*pos_x), and
// pos.x ranges over ceil(unitWidth/8)*unitHeight. It keeps eight
// accumulators o0..o7, and the remain ladder writes
// min(8, unitWidth - srcX) columns.
"__kernel void gemmWinogradW2(__read_only image2d_t uInput,__read_only image2d_t uKernel,__write_only image2d_t uOutput,\n"
" __private const int unitWidth,__private const int unitHeight,__private const int dstChannelC4,__private const int multiLength,__private const int alpha2) {\n"
" \n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" const int unitWidth8=(unitWidth+7)/8;\n"
" if (pos.x<unitWidth8*unitHeight && pos.y<alpha2*dstChannelC4) {\n"
" \n"
" const int pos_x=pos.x % unitWidth8;\n"
" const int pos_y=pos.x/unitWidth8;\n"
" const int pos_z=pos.y % dstChannelC4;\n"
" const int pos_w=pos.y/dstChannelC4;\n"
" FLOAT4 o0=(FLOAT4)(0);\n"
" FLOAT4 o1=(FLOAT4)(0);\n"
" FLOAT4 o2=(FLOAT4)(0);\n"
" FLOAT4 o3=(FLOAT4)(0);\n"
" FLOAT4 o4=(FLOAT4)(0);\n"
" FLOAT4 o5=(FLOAT4)(0);\n"
" FLOAT4 o6=(FLOAT4)(0);\n"
" FLOAT4 o7=(FLOAT4)(0);\n"
" int srcY=mad24(pos_w,unitHeight,pos_y);\n"
" int srcX=pos_x << 3;\n"
" for (int k=0; k<multiLength; ++k) {\n"
" __private int index=mul24(k,4);\n"
" __private int x_offset=mul24(k,unitWidth);\n"
" FLOAT4 k0=RI_F(uKernel,SAMPLER,(int2)(index,pos.y));\n"
" FLOAT4 k1=RI_F(uKernel,SAMPLER,(int2)(index+1,pos.y));\n"
" FLOAT4 k2=RI_F(uKernel,SAMPLER,(int2)(index+2,pos.y));\n"
" FLOAT4 k3=RI_F(uKernel,SAMPLER,(int2)(index+3,pos.y));\n"
" FLOAT4 s0=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset,srcY));\n"
" FLOAT4 s1=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+1,srcY));\n"
" FLOAT4 s2=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+2,srcY));\n"
" FLOAT4 s3=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+3,srcY));\n"
" FLOAT4 s4=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+4,srcY));\n"
" FLOAT4 s5=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+5,srcY));\n"
" FLOAT4 s6=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+6,srcY));\n"
" FLOAT4 s7=RI_F(uInput,SAMPLER,(int2)(srcX+x_offset+7,srcY));\n"
" o0=mad(s0.x,k0,o0);\n"
" o0=mad(s0.y,k1,o0);\n"
" o0=mad(s0.z,k2,o0);\n"
" o0=mad(s0.w,k3,o0);\n"
" o1=mad(s1.x,k0,o1);\n"
" o1=mad(s1.y,k1,o1);\n"
" o1=mad(s1.z,k2,o1);\n"
" o1=mad(s1.w,k3,o1);\n"
" o2=mad(s2.x,k0,o2);\n"
" o2=mad(s2.y,k1,o2);\n"
" o2=mad(s2.z,k2,o2);\n"
" o2=mad(s2.w,k3,o2);\n"
" o3=mad(s3.x,k0,o3);\n"
" o3=mad(s3.y,k1,o3);\n"
" o3=mad(s3.z,k2,o3);\n"
" o3=mad(s3.w,k3,o3);\n"
" \n"
" o4=mad(s4.x,k0,o4);\n"
" o4=mad(s4.y,k1,o4);\n"
" o4=mad(s4.z,k2,o4);\n"
" o4=mad(s4.w,k3,o4);\n"
" o5=mad(s5.x,k0,o5);\n"
" o5=mad(s5.y,k1,o5);\n"
" o5=mad(s5.z,k2,o5);\n"
" o5=mad(s5.w,k3,o5);\n"
" o6=mad(s6.x,k0,o6);\n"
" o6=mad(s6.y,k1,o6);\n"
" o6=mad(s6.z,k2,o6);\n"
" o6=mad(s6.w,k3,o6);\n"
" o7=mad(s7.x,k0,o7);\n"
" o7=mad(s7.y,k1,o7);\n"
" o7=mad(s7.z,k2,o7);\n"
" o7=mad(s7.w,k3,o7);\n"
" }\n"
" __private int out_y_idx=mad24(pos_z,unitHeight,pos_y);\n"
" __private int out_x_idx=mad24(pos_w,unitWidth,srcX);\n"
" const int remain=unitWidth-srcX;\n"
" if(remain >= 8){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" WI_F(uOutput,(int2)(out_x_idx+6,out_y_idx),o6);\n"
" WI_F(uOutput,(int2)(out_x_idx+7,out_y_idx),o7);\n"
" }else if(remain == 7){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" WI_F(uOutput,(int2)(out_x_idx+6,out_y_idx),o6);\n"
" }else if(remain == 6){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" WI_F(uOutput,(int2)(out_x_idx+5,out_y_idx),o5);\n"
" }else if(remain == 5){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" WI_F(uOutput,(int2)(out_x_idx+4,out_y_idx),o4);\n"
" }else if(remain == 4){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" WI_F(uOutput,(int2)(out_x_idx+3,out_y_idx),o3);\n"
" }else if(remain == 3){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" WI_F(uOutput,(int2)(out_x_idx+2,out_y_idx),o2);\n"
" }else if(remain == 2){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" WI_F(uOutput,(int2)(out_x_idx+1,out_y_idx),o1);\n"
" }else if(remain == 1){\n"
" WI_F(uOutput,(int2)(out_x_idx,out_y_idx),o0);\n"
" }\n"
" }\n"
"}\n"
// PADZEROSVEC(k, channel, data0..data3): when INPUT_CHANNEL_LEAVE is
// defined, it zeroes each dataI whose channel index 4k+I >= channel;
// otherwise it expands to nothing.
// NOTE(review): no kernel in this program references PADZEROSVEC.
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
// gemm_conv: one work-item per (pos.x, pos.y); the inline kernel comment
// labels these (cout/4, b). Work-items outside
// (global_size_dim0, global_size_dim1) return immediately.
// The accumulator starts from the bias texel at (pos.x, 0). For each of
// srcChannelC4 input blocks k, it loads 16 consecutive weights starting at
// weight[pos.x*16 + k*dstChannelC4*16] and applies them as four FLOAT4
// columns, scaled by in.x, in.y, in.z and in.w of the input texel (k, pos.y).
// Optional RELU (max with 0) and/or RELU6 (clamp to [0, 6]) are applied
// before writing to (pos.x, pos.y).
// NOTE(review): the batch parameter is not used by this kernel.
"__kernel void gemm_conv(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
" __global const FLOAT *weight,\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" FLOAT4 out=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(k,pos.y));\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
" \n"
" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n"
" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n"
" out=mad((FLOAT4)in.z,(FLOAT4)weights.s89ab,out);\n"
" out=mad((FLOAT4)in.w,(FLOAT4)weights.scdef,out);\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos.y),out);\n"
"}\n"
// gemm_conv_b2: same computation as gemm_conv, but each work-item produces
// two rows, pos_y = 2*pos.y and pos_y + 1, sharing one weight load per k.
// Row pos_y is always written. Row pos_y + 1 is written only when
// pos_y + 1 < batch.
// NOTE(review): the input read of row pos_y + 1 is not guarded. Presumably
// it is tolerated because of the clamp-to-border SAMPLER -- confirm.
// pos_x is computed but never used.
"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
" __global const FLOAT *weight,\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int pos_x=pos.x << 2;\n"
" int pos_y=pos.y << 1;\n"
" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" FLOAT4 out0=bias0,out1=bias0;\n"
" \n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" FLOAT4 in0=RI_F(input,SAMPLER,(int2)(k,pos_y));\n"
" FLOAT4 in1=RI_F(input,SAMPLER,(int2)(k,pos_y+1));\n"
" FLOAT16 weights=vload16(0,weight+weight_offset+k*weight_oc_offset);\n"
" \n"
" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n"
" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n"
" out0=mad((FLOAT4)in0.z,(FLOAT4)weights.s89ab,out0);\n"
" out0=mad((FLOAT4)in0.w,(FLOAT4)weights.scdef,out0);\n"
" \n"
" out1=mad((FLOAT4)in1.x,(FLOAT4)weights.s0123,out1);\n"
" out1=mad((FLOAT4)in1.y,(FLOAT4)weights.s4567,out1);\n"
" out1=mad((FLOAT4)in1.z,(FLOAT4)weights.s89ab,out1);\n"
" out1=mad((FLOAT4)in1.w,(FLOAT4)weights.scdef,out1);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos_y),out0);\n"
" if(pos_y+1<batch)\n"
" WI_F(output,(int2)(pos.x,pos_y+1),out1);\n"
"}\n"
;
}