mirror of https://github.com/alibaba/MNN.git
186 lines
7.9 KiB
C++
186 lines
7.9 KiB
C++
#include "opencl_source_map.hpp"
|
|
namespace MNN {

// OpenCL C program text for the weight-quantized GEMM convolution kernels
// `gemm_conv` and `gemm_conv_b2`, stored as a single string assembled from
// adjacent string literals (one literal per line of kernel source).
//
// The text uses FLOAT4, FLOAT16, COMPUTE_FLOAT8, COMPUTE_FLOAT16,
// CONVERT_FLOAT16, CONVERT_COMPUTE_FLOAT8, RI_F, WI_F and QUANT_BIT without
// defining them, so they have to be supplied when the program is built
// (presumably through build options -- TODO confirm against the caller).
// MNN_SUPPORT_FP16, INPUT_CHANNEL_LEAVE, RELU and RELU6 are optional switches
// that the text only tests with #ifdef.
//
// NOTE(review): this looks like generated output; if it is, edit the
// originating kernel source instead of this string -- confirm before changing.
const char* gemm_int =
// ---- Preamble: optional fp16 extension, the shared leading parameter list,
// ---- the out-of-range early return, and the image sampler.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// With INPUT_CHANNEL_LEAVE defined, PADZEROSVEC replaces dataN by 0 when
// 4*k+N is not below `channel`; otherwise the macro expands to nothing.
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
// ---- Kernel gemm_conv ------------------------------------------------------
// Each work-item starts from the bias texel at (pos.x, 0). For every k in
// [0, srcChannelC4) it reads one input texel at (k, pos.y), loads 16 weights,
// dequantizes them as value*scale+offset and accumulates four mads into one
// FLOAT4, which is finally written to `output` at (pos.x, pos.y).
//
// Weight storage: with QUANT_BIT == 8 the kernel loads 16 chars per step;
// otherwise it loads 8 bytes and unpacks 16 four-bit values (high nibble
// first), subtracting 8 from each.
// Scale/offset: 8 floats are loaded per pos.x from the row selected by
// (k*4)/blockDim; even lanes are used as scales and odd lanes as offsets.
// Only the pointer type of `weight` differs between the two #if branches of
// the parameter list.
"__kernel void gemm_conv(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" FLOAT4 out=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
"#if QUANT_BIT == 8\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#else \n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" FLOAT4 in=RI_F(input,SAMPLER,(int2)(k,pos.y));\n"
"#if QUANT_BIT == 8\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#else\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out=mad((FLOAT4)in.x,(FLOAT4)weights.s0123,out);\n"
" out=mad((FLOAT4)in.y,(FLOAT4)weights.s4567,out);\n"
" out=mad((FLOAT4)in.z,(FLOAT4)weights.s89ab,out);\n"
" out=mad((FLOAT4)in.w,(FLOAT4)weights.scdef,out);\n"
" }\n"
" \n"
"#ifdef RELU\n"
" out=fmax(out,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out=clamp(out,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos.y),out);\n"
"}\n"
// ---- Kernel gemm_conv_b2 ---------------------------------------------------
// Same weight loading and dequantization as gemm_conv, but each work-item
// produces two output rows: pos_y = pos.y*2 and pos_y+1. Both rows start from
// the same bias texel and share the dequantized weights of each step. The
// first row is always written; the second only when pos_y+1 < batch.
// NOTE(review): `pos_x` is assigned below but never read in this kernel.
"__kernel void gemm_conv_b2(GLOBAL_SIZE_DIM2\n"
" __read_only image2d_t input,\n"
"#if QUANT_BIT == 8\n"
" __global const char *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#else\n"
" __global const uchar *weight,\n"
" __global const float *dequantScaleOffset,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int dstChannelC4,\n"
" __private const int srcChannelC4,\n"
" __private const int batch\n"
" ,__private const int blockDim\n"
" ,__private const int srcChannel\n"
") {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); //cout/4,b\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int pos_x=pos.x << 2;\n"
" int pos_y=pos.y << 1;\n"
" FLOAT4 bias0=RI_F(bias,SAMPLER,(int2)(pos.x,0));\n"
" FLOAT4 out0=bias0,out1=bias0;\n"
" \n"
"#if QUANT_BIT == 8\n"
" int weight_offset=pos.x*16;\n"
" int weight_oc_offset=dstChannelC4*16;\n"
"#else\n"
" int weight_offset=pos.x*8;\n"
" int weight_oc_offset=dstChannelC4*8;\n"
"#endif\n"
" for (int k=0; k<srcChannelC4; ++k) {\n"
" int kindex=(k*4)/blockDim*dstChannelC4*8;\n"
" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(pos.x,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT16 scale=(COMPUTE_FLOAT16)(ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6,\n"
" ScaleOffset.s0,ScaleOffset.s2,ScaleOffset.s4,ScaleOffset.s6);\n"
" COMPUTE_FLOAT16 offset=(COMPUTE_FLOAT16)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7,\n"
" ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n"
" FLOAT4 in0=RI_F(input,SAMPLER,(int2)(k,pos_y));\n"
" FLOAT4 in1=RI_F(input,SAMPLER,(int2)(k,pos_y+1));\n"
"#if QUANT_BIT == 8\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,weight+weight_offset+k*weight_oc_offset))*scale+offset;\n"
"#else\n"
" uchar8 charWeightsInt4=vload8(0,weight+weight_offset+k*weight_oc_offset);\n"
" char16 charWeights=0;\n"
" charWeights.s0=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights.s1=(charWeightsInt4.s0 & 15)-8;\n"
" charWeights.s2=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights.s3=(charWeightsInt4.s1 & 15)-8;\n"
" charWeights.s4=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights.s5=(charWeightsInt4.s2 & 15)-8;\n"
" charWeights.s6=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights.s7=(charWeightsInt4.s3 & 15)-8;\n"
" charWeights.s8=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights.s9=(charWeightsInt4.s4 & 15)-8;\n"
" charWeights.sa=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights.sb=(charWeightsInt4.s5 & 15)-8;\n"
" charWeights.sc=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights.sd=(charWeightsInt4.s6 & 15)-8;\n"
" charWeights.se=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights.sf=(charWeightsInt4.s7 & 15)-8;\n"
" FLOAT16 weights=CONVERT_FLOAT16(charWeights)*scale+offset;\n"
"#endif\n"
" PADZEROSVEC(k,srcChannel,weights.s0123,weights.s4567,weights.s89ab,weights.scdef);\n"
" \n"
" out0=mad((FLOAT4)in0.x,(FLOAT4)weights.s0123,out0);\n"
" out0=mad((FLOAT4)in0.y,(FLOAT4)weights.s4567,out0);\n"
" out0=mad((FLOAT4)in0.z,(FLOAT4)weights.s89ab,out0);\n"
" out0=mad((FLOAT4)in0.w,(FLOAT4)weights.scdef,out0);\n"
" \n"
" out1=mad((FLOAT4)in1.x,(FLOAT4)weights.s0123,out1);\n"
" out1=mad((FLOAT4)in1.y,(FLOAT4)weights.s4567,out1);\n"
" out1=mad((FLOAT4)in1.z,(FLOAT4)weights.s89ab,out1);\n"
" out1=mad((FLOAT4)in1.w,(FLOAT4)weights.scdef,out1);\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" WI_F(output,(int2)(pos.x,pos_y),out0);\n"
" if(pos_y+1<batch)\n"
" WI_F(output,(int2)(pos.x,pos_y+1),out1);\n"
"}\n"
;

}  // namespace MNN
|