// MNN/source/backend/opencl/execution/cl/gemm_buf_mnn_cl.cpp
//
// 84 lines
// 2.9 KiB
// C++

#include "opencl_source_map.hpp"
namespace MNN {
// Everything up to the matching #endif is compiled only when
// MNN_OPENCL_BUFFER_CLOSED is not defined.
#ifndef MNN_OPENCL_BUFFER_CLOSED
// gemm_buf holds OpenCL C source text as a single C string. The string is
// written as one literal fragment per kernel source line; the C++ compiler
// concatenates adjacent fragments, and each fragment carries its own \n.
// The text defines two kernels, transpose_pad and transpose_bias.
//
// FLOAT and FLOAT4 are used by the kernels but are not defined anywhere in
// this string; presumably they are supplied as build options by the host
// code - TODO confirm. How this string is registered and compiled is also
// not visible here (presumably via opencl_source_map.hpp - verify).
//
// NOTE: every character inside the quotes is part of the kernel program.
// Only the C++ comments between fragments are documentation.
const char* gemm_buf =
// Enable half precision support when the build asks for it.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// Helper macros shared by both kernels. On the next two lines one literal
// ends and another starts in the middle of the line, so after concatenation
// each #define is still a single line of kernel source.
// GLOBAL_SIZE_DIM2 expands to the two leading kernel parameters
// global_size_dim0 and global_size_dim1 (with a trailing comma).
// UNIFORM_BOUNDRY_CHECK returns early from the kernel when either index is
// outside those two sizes.
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
// ---------------------------------------------------------------------------
// Kernel transpose_pad
// Each work-item handles one 4x4 tile: global id 0 selects a group of four
// M positions (idx_m4), global id 1 selects a group of four K positions
// (idx_k4). Parameters alignK and area are declared but never referenced in
// the kernel body.
// ---------------------------------------------------------------------------
"// [K/4,M,4] -> [alignK,alignM]\n"
"__kernel void transpose_pad(GLOBAL_SIZE_DIM2\n"
" const int alignM,\n"
" const int alignK,\n"
" const int M,\n"
" const int K,\n"
" const int area,\n"
" __global const FLOAT* input,\n"
" __global FLOAT* output\n"
" ) {\n"
" const int idx_m4=get_global_id(0); // idx M\n"
" const int idx_k4=get_global_id(1); // idx K\n"
" UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n"
// Convert tile indices to element indices (multiply by 4).
" const int idx_m=idx_m4 << 2;\n"
" const int idx_k=idx_k4 << 2;\n"
// K_4 is K divided by 4, rounded up: the number of packed K groups.
" const int K_4=(K+3) >> 2;\n"
// Input is indexed as ((k group * M) + m) * 4; output is indexed as
// k * alignM + m.
" const int in_offset_base=(idx_k4*M+idx_m)*4;\n"
" const int out_offset_base=idx_k*alignM+idx_m;\n"
" \n"
// Load four consecutive M positions, each a vector of four K values.
// A position whose K group is >= K_4 or whose M index is >= M yields a zero
// vector instead of a load, so the padded part of the output is zero.
" FLOAT4 m0k4=(idx_k4 >= K_4 || idx_m+0 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base);\n"
" FLOAT4 m1k4=(idx_k4 >= K_4 || idx_m+1 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+4);\n"
" FLOAT4 m2k4=(idx_k4 >= K_4 || idx_m+2 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+8);\n"
" FLOAT4 m3k4=(idx_k4 >= K_4 || idx_m+3 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+12);\n"
" \n"
// Write the tile transposed: output row idx_k+c (stride alignM) receives
// component c (x, y, z, w) of each of the four loaded vectors.
" vstore4((FLOAT4)(m0k4.x,m1k4.x,m2k4.x,m3k4.x),0,output+out_offset_base);\n"
" vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n"
" vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n"
" vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n"
"}\n"
// M_VEC is the number of consecutive M rows one transpose_bias work-item
// processes. It defaults to 1 when the build does not define it.
"#ifndef M_VEC\n"
"#define M_VEC 1\n"
"#endif\n"
// ---------------------------------------------------------------------------
// Kernel transpose_bias
// Each work-item covers M_VEC rows of M (global id 0) and one group of four
// N values (global id 1). For every row it loads four values from input0,
// adds the four values loaded from input1, applies the activation selected
// by the RELU / RELU6 / PRELU build defines, and stores the result.
// input1 is read once per work-item and indexed by N only; judging by the
// kernel name it is the bias vector - TODO confirm against the host code.
// Parameters alignM and area are declared but never referenced in the body.
// ---------------------------------------------------------------------------
"// [alignM,alignN] -> [N/4,B,area,N4] (M=B*area)\n"
"__kernel void transpose_bias(GLOBAL_SIZE_DIM2\n"
" const int alignM,\n"
" const int alignN,\n"
" const int M,\n"
" const int N,\n"
" const int area,\n"
" __global const FLOAT* input0,\n"
" __global const FLOAT* input1,\n"
" __global FLOAT* output\n"
// The slope pointer parameter exists only in PRELU builds.
" #ifdef PRELU\n"
" ,__global const FLOAT *slope_ptr\n"
" #endif\n"
" ) {\n"
" int idx_m=get_global_id(0); // idx M\n"
" int idx_n4=get_global_id(1); // idx N\n"
// The bounds check uses the unscaled idx_m, i.e. it is done in units of
// M_VEC rows, before idx_m is multiplied by M_VEC below.
" UNIFORM_BOUNDRY_CHECK(idx_m,idx_n4);\n"
" const int idx_n=idx_n4 << 2;\n"
" idx_m=idx_m*M_VEC;\n"
// Values that do not depend on the row are loaded once, outside the loop.
" FLOAT4 res1=vload4(0,input1+idx_n);\n"
" #ifdef PRELU\n"
" FLOAT4 slope_in=vload4(0,slope_ptr+idx_n);\n"
" #endif\n"
// NOTE(review): the loop has no idx_m+i < M check. It appears to rely on the
// host choosing global_size_dim0 so that all M_VEC rows are valid - confirm
// against the launch code.
" #pragma unroll\n"
" for(int i=0; i<M_VEC; i++) {\n"
// input0 is read with a row stride of alignN.
" FLOAT4 res0=vload4(0,input0+(idx_m+i)*alignN+idx_n);\n"
" FLOAT4 res=res0+res1;\n"
" #ifdef RELU\n"
" res=fmax(res,(FLOAT4)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n"
" #endif\n"
// PRELU: per component, keep res where res >= 0, otherwise use res*slope_in.
" #ifdef PRELU\n"
" res=select(res*slope_in,res,res >= 0);\n"
" #endif\n"
// Output is indexed as ((n group * M) + m) * 4, which matches the packed
// layout named in the kernel header comment.
" vstore4(res,0,output+((idx_n4*M+idx_m+i) << 2));\n"
" }\n"
"}\n"
;
#endif
}