// Mirror of https://github.com/alibaba/MNN.git - C++ source, listed by the mirror as 348 lines, 12 KiB.
#include "opencl_source_map.hpp"

namespace MNN {

#ifndef MNN_OPENCL_BUFFER_CLOSED

// OpenCL C source of the `matmul_buf` kernel, stored as one string constant.
// The definition is compiled out when MNN_OPENCL_BUFFER_CLOSED is defined.
//
// Kernel contract (as written in the source below):
//   * Work-item (get_global_id(0), get_global_id(1)) = (pos.x, pos.y) produces
//     the 4x4 output tile whose first column is idn = pos.x*4 and whose first
//     row is idm = pos.y*4. Items outside global_size_dim0/1 return at once.
//   * output_c is addressed as idm*N + idn, i.e. row-major M x N.
//   * input_a is read as M x K, or as K x M when TRANSPOSE_A is defined.
//   * input_b is read as K x N, or as N x K when TRANSPOSE_B is defined.
//   * input_c exists only when BIAS is defined; it seeds every row of the tile.
//
// Macros the kernel text uses but does not define (they must come from the
// program build options): FLOAT, COMPUTE_FLOAT, COMPUTE_FLOAT2/3/4,
// CONVERT_FLOAT4, CONVERT_COMPUTE_FLOAT2/3/4, and the switches BIAS,
// TRANSPOSE_A, TRANSPOSE_B, K_LEAVE, M_LEAVE, M_LEAVE_NUM, N_LEAVE,
// N_LEAVE_NUM.
// NOTE(review): M_LEAVE_NUM / N_LEAVE_NUM look like M % 4 / N % 4 - confirm
// against the host code that builds this program.
const char* matmul_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) ""if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { ""return; ""}\n"
"__kernel void matmul_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a,\n"
" __global const FLOAT* input_b,\n"
" #ifdef BIAS\n"
" __global const FLOAT* input_c,\n"
" #endif\n"
" __global FLOAT* output_c,\n"
" __private const int M,\n"
" __private const int N,\n"
" __private const int K) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); // N M\n"
" DEAL_NON_UNIFORM_DIM2(pos.x,pos.y);\n"
" const int idn=pos.x << 2;\n"
" const int idm=pos.y << 2;\n"
" \n"
// Accumulators: out[r] is row idm+r of the tile, its four lanes are columns
// idn..idn+3. They start from the bias vector (BIAS) or from zero.
// NOTE(review): the bias is fetched with vload4 at input_c+idn even for an
// edge tile, so input_c presumably has four readable elements there - confirm
// against the host-side buffer size.
" COMPUTE_FLOAT4 out[4];\n"
" #ifdef BIAS\n"
" COMPUTE_FLOAT4 bias=CONVERT_COMPUTE_FLOAT4(vload4(0,input_c+idn));\n"
" #pragma unroll\n"
" for(int i=0; i<4; ++i){\n"
" out[i]=bias;\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for(int i=0; i<4; ++i){\n"
" out[i]=(COMPUTE_FLOAT4)0;\n"
" }\n"
" #endif\n"
// K is walked in groups of four. With K_LEAVE the last group is excluded from
// the vector loop and handled element by element in the tail loop further down.
// (The former `remain` local was never read and has been dropped.)
" const int K4=(K+3)/4;\n"
" #ifdef K_LEAVE\n"
" const int loop_end=max(K4-1,0);\n"
" #else\n"
" const int loop_end=K4;\n"
" #endif\n"
" \n"
// Base pointers for this tile's rows of A and columns of B.
" #ifdef TRANSPOSE_A\n"
" __global const FLOAT* input_a_offset=input_a+idm; // K x M\n"
" #else\n"
" __global const FLOAT* input_a_offset=input_a+idm*K; // M x K\n"
" #endif\n"
" \n"
" #ifdef TRANSPOSE_B\n"
" __global const FLOAT* input_b_offset=input_b+idn*K; // N x K\n"
" #else\n"
" __global const FLOAT* input_b_offset=input_b+idn; // K x N\n"
" #endif\n"
" \n"
// Vector loop: A[r] = four consecutive k values of row idm+r,
// B[j] = columns idn..idn+3 at k index kindex+j.
" for (int k=0; k<loop_end; ++k) {\n"
" int kindex=k << 2;\n"
" COMPUTE_FLOAT4 A[4]; // m4 x k4\n"
" COMPUTE_FLOAT4 B[4]; // k4 x n4\n"
// Load A. For an edge tile in M only M_LEAVE_NUM rows are read; the remaining
// rows of A are set to zero.
" #ifdef M_LEAVE\n"
" if(idm+3 >= M){\n"
" #ifdef TRANSPOSE_A\n"
" #if M_LEAVE_NUM == 3\n"
" {\n"
" COMPUTE_FLOAT3 tmp0=CONVERT_COMPUTE_FLOAT3(vload3(0,input_a_offset+kindex*M));\n"
" COMPUTE_FLOAT3 tmp1=CONVERT_COMPUTE_FLOAT3(vload3(0,input_a_offset+(kindex+1)*M));\n"
" COMPUTE_FLOAT3 tmp2=CONVERT_COMPUTE_FLOAT3(vload3(0,input_a_offset+(kindex+2)*M));\n"
" COMPUTE_FLOAT3 tmp3=CONVERT_COMPUTE_FLOAT3(vload3(0,input_a_offset+(kindex+3)*M));\n"
" \n"
" A[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" A[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" A[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" }\n"
" #elif M_LEAVE_NUM == 2\n"
" {\n"
" COMPUTE_FLOAT2 tmp0=CONVERT_COMPUTE_FLOAT2(vload2(0,input_a_offset+kindex*M));\n"
" COMPUTE_FLOAT2 tmp1=CONVERT_COMPUTE_FLOAT2(vload2(0,input_a_offset+(kindex+1)*M));\n"
" COMPUTE_FLOAT2 tmp2=CONVERT_COMPUTE_FLOAT2(vload2(0,input_a_offset+(kindex+2)*M));\n"
" COMPUTE_FLOAT2 tmp3=CONVERT_COMPUTE_FLOAT2(vload2(0,input_a_offset+(kindex+3)*M));\n"
" \n"
" A[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" A[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" A[2]=(COMPUTE_FLOAT4)0;\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" }\n"
" #elif M_LEAVE_NUM == 1\n"
" {\n"
" A[0]=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)input_a_offset[kindex*M],(COMPUTE_FLOAT)input_a_offset[(kindex+1)*M],(COMPUTE_FLOAT)input_a_offset[(kindex+2)*M],(COMPUTE_FLOAT)input_a_offset[(kindex+3)*M]);\n"
" A[1]=(COMPUTE_FLOAT4)0;\n"
" A[2]=(COMPUTE_FLOAT4)0;\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" }\n"
" #endif\n"
" #else\n"
" #if M_LEAVE_NUM == 3\n"
" A[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex));\n"
" A[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+K));\n"
" A[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+2*K));\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" #elif M_LEAVE_NUM == 2\n"
" A[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex));\n"
" A[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+K));\n"
" A[2]=(COMPUTE_FLOAT4)0;\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" #elif M_LEAVE_NUM == 1\n"
" A[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex));\n"
" A[1]=(COMPUTE_FLOAT4)0;\n"
" A[2]=(COMPUTE_FLOAT4)0;\n"
" A[3]=(COMPUTE_FLOAT4)0;\n"
" #endif\n"
" #endif\n"
" } else\n"
" #endif\n"
" {\n"
" #ifdef TRANSPOSE_A\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex*M));\n"
" COMPUTE_FLOAT4 tmp1=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+1)*M));\n"
" COMPUTE_FLOAT4 tmp2=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+2)*M));\n"
" COMPUTE_FLOAT4 tmp3=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+(kindex+3)*M));\n"
" \n"
" A[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" A[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" A[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" A[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" }\n"
" #else\n"
" A[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex));\n"
" A[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+K));\n"
" A[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+2*K));\n"
" A[3]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+kindex+3*K));\n"
" #endif\n"
" }\n"
" \n"
// Load B. For an edge tile in N only N_LEAVE_NUM columns are read; the
// remaining lanes of every B[j] are set to zero.
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" #ifdef TRANSPOSE_B\n"
" #if N_LEAVE_NUM == 3\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n"
" COMPUTE_FLOAT4 tmp1=idn+1 >= N ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+K));\n"
" COMPUTE_FLOAT4 tmp2=idn+2 >= N ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+2*K));\n"
" \n"
" B[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,0);\n"
" B[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,0);\n"
" B[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,0);\n"
" B[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,0);\n"
" }\n"
" #elif N_LEAVE_NUM == 2\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n"
" COMPUTE_FLOAT4 tmp1=idn+1 >= N ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+K));\n"
" \n"
" B[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,0,0);\n"
" B[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,0,0);\n"
" B[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,0,0);\n"
" B[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,0,0);\n"
" }\n"
" #elif N_LEAVE_NUM == 1\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n"
" \n"
" B[0]=(COMPUTE_FLOAT4)(tmp0.x,0,0,0);\n"
" B[1]=(COMPUTE_FLOAT4)(tmp0.y,0,0,0);\n"
" B[2]=(COMPUTE_FLOAT4)(tmp0.z,0,0,0);\n"
" B[3]=(COMPUTE_FLOAT4)(tmp0.w,0,0,0);\n"
" }\n"
" #endif\n"
" #else\n"
" #if N_LEAVE_NUM == 3\n"
" {\n"
" B[0]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT3(vload3(0,input_b_offset+kindex*N)),0);\n"
" B[1]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT3(vload3(0,input_b_offset+(kindex+1)*N)),0);\n"
" B[2]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT3(vload3(0,input_b_offset+(kindex+2)*N)),0);\n"
" B[3]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT3(vload3(0,input_b_offset+(kindex+3)*N)),0);\n"
" }\n"
" #elif N_LEAVE_NUM == 2\n"
" {\n"
" B[0]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT2(vload2(0,input_b_offset+kindex*N)),0,0);\n"
" B[1]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT2(vload2(0,input_b_offset+(kindex+1)*N)),0,0);\n"
" B[2]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT2(vload2(0,input_b_offset+(kindex+2)*N)),0,0);\n"
" B[3]=(COMPUTE_FLOAT4)(CONVERT_COMPUTE_FLOAT2(vload2(0,input_b_offset+(kindex+3)*N)),0,0);\n"
" }\n"
" #elif N_LEAVE_NUM == 1\n"
" {\n"
" B[0]=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)input_b_offset[kindex*N],0,0,0);\n"
" B[1]=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)input_b_offset[(kindex+1)*N],0,0,0);\n"
" B[2]=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)input_b_offset[(kindex+2)*N],0,0,0);\n"
" B[3]=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)input_b_offset[(kindex+3)*N],0,0,0);\n"
" }\n"
" #endif\n"
" #endif\n"
" } else\n"
" #endif\n"
" {\n"
" #ifdef TRANSPOSE_B\n"
" {\n"
" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n"
" COMPUTE_FLOAT4 tmp1=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+K));\n"
" COMPUTE_FLOAT4 tmp2=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+2*K));\n"
" COMPUTE_FLOAT4 tmp3=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+3*K));\n"
" \n"
" B[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n"
" B[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n"
" B[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n"
" B[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n"
" }\n"
" #else\n"
" B[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex*N));\n"
" B[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+1)*N));\n"
" B[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+2)*N));\n"
" B[3]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+3)*N));\n"
" #endif\n"
" }\n"
" \n"
// 4x4x4 multiply-accumulate: each output row gains A[row].k * B[k] for the
// four k values of this group.
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].x,B[0],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].y,B[1],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].z,B[2],out[vec_m]);\n"
" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].w,B[3],out[vec_m]);\n"
" }\n"
" }\n"
// Tail loop (K_LEAVE only): one k at a time from loop_end*4 up to K.
// A and B are zero-initialised because the edge-tile branches below fill only
// the valid lanes, while all four lanes of both vectors are read by the mad()
// calls at the end of the loop body. The lanes that stay zero feed only
// accumulator elements that the store section never writes out.
" #ifdef K_LEAVE\n"
" for (int k=loop_end << 2; k<K; ++k){\n"
" COMPUTE_FLOAT4 A=(COMPUTE_FLOAT4)0; // m4\n"
" COMPUTE_FLOAT4 B=(COMPUTE_FLOAT4)0; // n4\n"
" #ifdef M_LEAVE\n"
" if(idm+3 >= M){\n"
" #ifdef TRANSPOSE_A\n"
" #if M_LEAVE_NUM == 3\n"
" A.s012=CONVERT_COMPUTE_FLOAT3(vload3(0,input_a_offset+k*M));\n"
" #elif M_LEAVE_NUM == 2\n"
" A.s01=CONVERT_COMPUTE_FLOAT2(vload2(0,input_a_offset+k*M));\n"
" #elif M_LEAVE_NUM == 1\n"
" A.s0=(COMPUTE_FLOAT)input_a_offset[k*M];\n"
" #endif\n"
" #else\n"
" A.x=(COMPUTE_FLOAT)input_a_offset[k];\n"
" #if M_LEAVE_NUM >= 2\n"
" A.y=(COMPUTE_FLOAT)input_a_offset[k+K];\n"
" #endif\n"
" #if M_LEAVE_NUM >= 3\n"
" A.z=(COMPUTE_FLOAT)input_a_offset[k+2*K];\n"
" #endif\n"
" #endif\n"
" } else\n"
" #endif\n"
" {\n"
" #ifdef TRANSPOSE_A\n"
" A=CONVERT_COMPUTE_FLOAT4(vload4(0,input_a_offset+k*M));\n"
" #else\n"
" A.x=(COMPUTE_FLOAT)input_a_offset[k];\n"
" A.y=(COMPUTE_FLOAT)input_a_offset[k+K];\n"
" A.z=(COMPUTE_FLOAT)input_a_offset[k+2*K];\n"
" A.w=(COMPUTE_FLOAT)input_a_offset[k+3*K];\n"
" #endif\n"
" }\n"
" \n"
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" #ifdef TRANSPOSE_B\n"
" B.x=(COMPUTE_FLOAT)input_b_offset[k];\n"
" #if N_LEAVE_NUM >= 2\n"
" B.y=(COMPUTE_FLOAT)input_b_offset[k+K];\n"
" #endif\n"
" #if N_LEAVE_NUM >= 3\n"
" B.z=(COMPUTE_FLOAT)input_b_offset[k+2*K];\n"
" #endif\n"
" #else\n"
" #if N_LEAVE_NUM == 3\n"
" B.s012=CONVERT_COMPUTE_FLOAT3(vload3(0,input_b_offset+k*N));\n"
" #elif N_LEAVE_NUM == 2\n"
" B.s01=CONVERT_COMPUTE_FLOAT2(vload2(0,input_b_offset+k*N));\n"
" #elif N_LEAVE_NUM == 1\n"
" B.s0=(COMPUTE_FLOAT)input_b_offset[k*N];\n"
" #endif\n"
" #endif\n"
" } else\n"
" #endif\n"
" {\n"
" #ifdef TRANSPOSE_B\n"
" B.x=(COMPUTE_FLOAT)input_b_offset[k];\n"
" B.y=(COMPUTE_FLOAT)input_b_offset[k+K];\n"
" B.z=(COMPUTE_FLOAT)input_b_offset[k+2*K];\n"
" B.w=(COMPUTE_FLOAT)input_b_offset[k+3*K];\n"
" #else\n"
" B=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+k*N));\n"
" #endif\n"
" }\n"
" out[0]=mad((COMPUTE_FLOAT4)A.x,B,out[0]);\n"
" out[1]=mad((COMPUTE_FLOAT4)A.y,B,out[1]);\n"
" out[2]=mad((COMPUTE_FLOAT4)A.z,B,out[2]);\n"
" out[3]=mad((COMPUTE_FLOAT4)A.w,B,out[3]);\n"
" }\n"
" #endif\n"
" \n"
" \n"
// Store. Edge tiles write only rows below M-idm and/or columns below N-idn
// (element by element for a column edge); interior tiles use vstore4 per row.
" const int out_offset=idm*N+idn;\n"
" #ifdef M_LEAVE\n"
" if(idm+3 >= M){\n"
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" for (int vec_m=0; vec_m<M-idm; ++vec_m){\n"
" COMPUTE_FLOAT *out_ptr=(COMPUTE_FLOAT*)&out[vec_m];\n"
" for(int vec_n=0; vec_n<N-idn; ++vec_n){\n"
" output_c[out_offset+vec_m*N+vec_n]=out_ptr[vec_n];\n"
" }\n"
" }\n"
" } else {\n"
" #endif\n"
" for (int vec_m=0; vec_m<M-idm; ++vec_m){\n"
" vstore4(CONVERT_FLOAT4(out[vec_m]),0,output_c+out_offset+vec_m*N);\n"
" }\n"
" \n"
" #ifdef N_LEAVE\n"
" }\n"
" #endif\n"
" } else{\n"
" #endif\n"
" #ifdef N_LEAVE\n"
" if(idn+3 >= N){\n"
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" COMPUTE_FLOAT *out_ptr=(COMPUTE_FLOAT*)&out[vec_m];\n"
" for(int vec_n=0; vec_n<N-idn; ++vec_n){\n"
" output_c[out_offset+vec_m*N+vec_n]=out_ptr[vec_n];\n"
" }\n"
" }\n"
" } else {\n"
" #endif\n"
" #pragma unroll\n"
" for (int vec_m=0; vec_m<4; ++vec_m){\n"
" vstore4(CONVERT_FLOAT4(out[vec_m]),0,output_c+out_offset+vec_m*N);\n"
" }\n"
" #ifdef N_LEAVE\n"
" }\n"
" #endif\n"
" #ifdef M_LEAVE\n"
" }\n"
" #endif\n"
"}\n"
;

#endif

}