MNN/source/backend/metal/AllShader.cpp

2422 lines
92 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "AllShader.hpp"
// Metal source text for the element-wise kernels `relu6` and `relu`.
//   relu6: out[i] = clamp(in[i], minV, maxV)            for i < size
//   relu : out[i] = fmax(v, 0) + fmin(v, 0) * minV      for i < size
// Both kernels index with gid.x only. `Param::remain` is declared but not read
// by either kernel. `M4` is not defined in this text; presumably a shared
// prelude supplies it when the shader is compiled — TODO confirm.
// The literal is passed on verbatim, so every character inside it (including
// the leading space on each line) is significant.
const char* shader_MetalReLU6_metal = R"metal(struct Param {
 float minV;
 float maxV;
 int size;
 int remain;
};
kernel void relu6(const device M4 *in [[buffer(0)]],
 device M4 *out [[buffer(1)]],
 constant Param &p [[buffer(2)]],
 uint3 gid [[thread_position_in_grid]]) {
 if (gid.x<p.size) {
 out[int(gid.x)]=clamp(in[int(gid.x)],(M4)p.minV,(M4)p.maxV);
 }
}
kernel void relu(const device M4 *in [[buffer(0)]],
 device M4 *out [[buffer(1)]],
 constant Param &p [[buffer(2)]],
 uint3 gid [[thread_position_in_grid]]) {
 if (gid.x<p.size) {
 auto V=in[int(gid.x)];
 out[int(gid.x)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(p.minV);
 }
}
)metal";
// Metal source text for the `conv_depthwise` kernel and its constant block
// `conv_dw_cst`.
// Each thread (gid.x, gid.y, gid.z) produces one output element: it clips the
// kernel window to the input extent (sx/ex, sy/ey), starts the accumulator from
// biasTerms[gid.z / batch], adds in4*wt4 for every remaining tap, and stores
// activate(result) at out[gid.z*output_size + gid.y*output_width + gid.x].
// M4, FLOAT4, UP_DIV, conv_activation_type and activate() are not defined in
// this text; conv_activation_type/activate() appear in
// shader_MetalConvolutionActivation_metal, the others presumably come from a
// shared prelude — TODO confirm.
// The two whitespace-only source lines are kept as escaped literals between the
// raw segments so the single space they contain cannot be lost.
const char* shader_MetalConvolutionDepthwise_metal =
R"metal(struct conv_dw_cst {
 int input_width;
 int input_height;
 int input_size;
 int output_width;
 int output_height;
 int output_size;
 int slice;
 int batch;
)metal"
" \n"
R"metal( int kernel_x;
 int kernel_y;
 int kernel_size;
 int stride_x;
 int stride_y;
 int pad_x;
 int pad_y;
 int dilation_x;
 int dilation_y;
 conv_activation_type activation;
};
kernel void conv_depthwise(const device M4 *in [[buffer(0)]],
 device M4 *out [[buffer(1)]],
 constant conv_dw_cst& cst [[buffer(2)]],
 const device M4 *wt [[buffer(3)]],
 const device M4 *biasTerms [[buffer(4)]],
 uint3 gid [[thread_position_in_grid]]) {
 if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;
)metal"
" \n"
R"metal( int oz=gid.z/cst.batch;
 int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;
 int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;
 int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));
 int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));
 int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));
 int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));
 offset_x += sx*cst.dilation_x;
 offset_y += sy*cst.dilation_y;
 auto z_wt=wt+(int)oz*cst.kernel_size;
 auto z_in=in+(int)gid.z*cst.input_size;
 auto z_out=out+(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;
 FLOAT4 result=FLOAT4(biasTerms[oz]);
 for (auto ky=sy,y=offset_y; ky<ey; ky++,y += cst.dilation_y) {
 for (auto kx=sx,x=offset_x; kx<ex; kx++,x += cst.dilation_x) {
 auto wt4=z_wt[ky*cst.kernel_x+kx];
 auto in4=z_in[ y*cst.input_width+x];
 result += FLOAT4(in4*wt4);
 }
 }
 *z_out=activate((M4)result,cst.activation);
}
)metal";
// Metal source text declaring `conv_activation_type` (None=0, ReLU=1, ReLU6=2)
// and the helper `activate(V, type)`:
//   ReLU  -> max(V, 0)
//   ReLU6 -> clamp(V, 0, 6)
//   anything else -> V unchanged
// `M4` is not defined in this text; presumably a shared prelude supplies it —
// TODO confirm.
const char* shader_MetalConvolutionActivation_metal = R"metal(typedef enum : int {
 None=0,
 ReLU=1,
 ReLU6=2,
} conv_activation_type;
inline M4 activate(M4 V,conv_activation_type type) {
 switch (type) {
 case ReLU:
 return max(V,(M4)0);
 case ReLU6:
 return clamp(V,(M4)0,(M4)6);
 default: // None
 return V;
 }
}
)metal";
// Metal source text for the general convolution kernels:
//   conv, convk3s1d1p1_w2z4, conv_s1d1p0_w2, conv_s1d1p0_w4, conv_z4, conv_z2.
//
// Addressing used by every kernel in this text (see the z_in / z_out base
// computations): an element lives at
//     slice*batch*size + batch_index*size + y*width + x
// so two consecutive slices are batch*size elements apart, for inputs
// (input_size) as well as outputs (output_size). Every per-slice pointer step
// below therefore multiplies by cst.batch.
//
// M4, M4x4, FLOAT4 and UP_DIV are not defined in this text; presumably a shared
// prelude supplies them — TODO confirm. conv_activation_type and activate()
// are declared in shader_MetalConvolutionActivation_metal.
//
// NOTE: the #define bodies are built from adjacent literals WITHOUT "\n", so
// each macro is a single line in the resulting Metal source. Do not add
// newlines inside them.
const char* shader_MetalConvolution_metal =
"#define CONV_UNROLL (4)\n"
// CONV_MUL_PACK_W2(x,y): 3x3 multiply-accumulate for two horizontally adjacent
// outputs; x uses input columns 0..2, y uses columns 1..3 of rows 0..2.
"#define CONV_MUL_PACK_W2(x,y) "
" x += FLOAT4(in00*k00);"
" y += FLOAT4(in01*k00);"
" x += FLOAT4(in01*k01);"
" y += FLOAT4(in02*k01);"
" x += FLOAT4(in02*k02);"
" y += FLOAT4(in03*k02);"
" "
" x += FLOAT4(in10*k10);"
" y += FLOAT4(in11*k10);"
" x += FLOAT4(in11*k11);"
" y += FLOAT4(in12*k11);"
" x += FLOAT4(in12*k12);"
" y += FLOAT4(in13*k12);"
" "
" x += FLOAT4(in20*k20);"
" y += FLOAT4(in21*k20);"
" x += FLOAT4(in21*k21);"
" y += FLOAT4(in22*k21);"
" x += FLOAT4(in22*k22);"
" y += FLOAT4(in23*k22);\n"
" \n"
// CONV_NEXT_FLT: step z_wt by `ws` (one output slice of weights) and reload
// the nine 3x3 taps k00..k22.
"#define CONV_NEXT_FLT "
" z_wt += ws; "
" "
" k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];"
" k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];"
" k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
// CONV_MUL_PACK_H2(x,y): same as the W2 variant but reading input rows 1..3.
// It is not referenced by any kernel in this text.
"#define CONV_MUL_PACK_H2(x,y) "
" x += FLOAT4(in10*k00);"
" y += FLOAT4(in11*k00);"
" x += FLOAT4(in11*k01);"
" y += FLOAT4(in12*k01);"
" x += FLOAT4(in12*k02);"
" y += FLOAT4(in13*k02);"
" "
" x += FLOAT4(in20*k10);"
" y += FLOAT4(in21*k10);"
" x += FLOAT4(in21*k11);"
" y += FLOAT4(in22*k11);"
" x += FLOAT4(in22*k12);"
" y += FLOAT4(in23*k12);"
" "
" x += FLOAT4(in30*k20);"
" y += FLOAT4(in31*k20);"
" x += FLOAT4(in31*k21);"
" y += FLOAT4(in32*k21);"
" x += FLOAT4(in32*k22);"
" y += FLOAT4(in33*k22);\n"
// Constant block shared by all kernels below.
"struct conv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int batch;\n"
" int oz_size;\n"
" int threadgroup_input_slice;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation; \n"
"};\n"
// conv: one thread per output element (x, y, slice*batch). The kernel window
// is clipped to the input (sx/ex, sy/ey) so no padding reads are issued; the
// accumulator starts from the bias and the activation is applied on store.
"kernel void conv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" int offset_x=(int)idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+(int)idx_c*cst.batch*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" auto in4=z_in[z*cst.input_size*cst.batch+y*dilation_h+x*cst.dilation_x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" *z_out=activate(M4(result),cst.activation);\n"
"}\n"
// convk3s1d1p1_w2z4: 3x3 kernel specialisation (per its name: stride 1,
// dilation 1, pad 1 — the body hard-codes a 3x3 window and unit stride).
// Each thread produces 2 adjacent output columns for up to CONV_UNROLL output
// slices; out-of-range input taps are replaced by zero.
"kernel void convk3s1d1p1_w2z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
// Same slice guard as conv_z4/conv_z2: threads whose first output slice is
// already out of range must not touch weights, bias or output.
" if (idx_c*CONV_UNROLL >= cst.output_slice) return;\n"
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
" bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n"
" int offset_x=(int)gid.x*2-cst.pad_x;\n"
" int offset_y=(int)gid.y-cst.pad_y;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" FLOAT4 result4=0,result5=0,result6=0,result7=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += (cst.input_size*cst.batch)) {\n"
" auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n"
" auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n"
" auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n"
" auto in03=(offset_x+3>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+3);\n"
" auto in10=(offset_x<0 || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+0);\n"
" auto in11=(offset_x+1>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+1);\n"
" auto in12=(offset_x+2>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+2);\n"
" auto in13=(offset_x+3>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+3);\n"
" \n"
" auto in20=(offset_x<0 || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+0);\n"
" auto in21=(offset_x+1>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+1);\n"
" auto in22=(offset_x+2>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+2);\n"
" auto in23=(offset_x+3>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+3);\n"
" \n"
" auto z_wt=z_flt;\n"
" auto k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];\n"
" auto k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];\n"
" auto k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
" CONV_MUL_PACK_W2(result0,result4);\n"
" if (valids[0]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result1,result5);\n"
" }\n"
" if (valids[1]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result2,result6);\n"
" }\n"
" if (valids[2]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result3,result7);\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result4+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" }\n"
// Next output slice is batch*output_size elements further on (matches the
// uz[0]*cst.batch*cst.output_size term in the z_out base).
" if (valids[0]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result5+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[1]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result6+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[2]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result7+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" }\n"
" }\n"
"}\n"
// conv_s1d1p0_w2: unpadded, unit-stride path (no padding/stride/dilation terms
// appear in the body). Each thread produces 2 adjacent output columns; the
// second column reuses each loaded input with the previous weight tap.
// NOTE(review): the last tap for result1 is loaded even when `valid` is false;
// confirm the input buffer tolerates that read at the row end.
"kernel void conv_s1d1p0_w2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" bool valid=(idx_w+1<cst.output_width);\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<cst.kernel_y; y++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x];\n"
" auto in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
" for (auto x=1; x<cst.kernel_x; x++) {\n"
" in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+x];\n"
" result1 += FLOAT4(in4_0*wt4);\n"
" wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
" }\n"
// Trailing tap of the second column: must use the same batch-aware slice
// stride as the two loads above, otherwise batch>1 reads the wrong slice.
" in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+cst.kernel_x];\n"
" result1 += FLOAT4(in4_0*wt4);\n"
" }\n"
" }\n"
" *z_out=activate(M4(result0),cst.activation);\n"
" if(valid) { *(z_out+1)=activate(M4(result1),cst.activation);}\n"
"}\n"
// conv_s1d1p0_w4: unpadded, unit-stride path producing 4 adjacent output
// columns per thread. The body reads exactly three weights (wt_base[0..2]) and
// six inputs (z_in_base[0..5]) per kernel row, i.e. it is written for a kernel
// width of 3 — NOTE(review): confirm the host only selects it for kernel_x==3.
"kernel void conv_s1d1p0_w4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" int3 uz=idx_w+int3(1,2,3);\n"
" bool3 valids=uz.xyz<cst.output_width;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<cst.kernel_y; y++) {\n"
" auto wt_base=z_wt+z*cst.kernel_size+y*cst.kernel_x;\n"
" auto wt4_0=wt_base[0];\n"
" auto wt4_1=wt_base[1];\n"
" auto wt4_2=wt_base[2];\n"
" auto z_in_base=z_in+z*cst.batch*cst.input_size+y*cst.input_width;\n"
" auto in4_0=z_in_base[0];\n"
" result0 += FLOAT4(in4_0*wt4_0);\n"
" \n"
" in4_0=z_in_base[1];\n"
" result0 += FLOAT4(in4_0*wt4_1);\n"
" result1 += FLOAT4(in4_0*wt4_0);\n"
" in4_0=z_in_base[2];\n"
" result0 += FLOAT4(in4_0*wt4_2);\n"
" result1 += FLOAT4(in4_0*wt4_1);\n"
" result2 += FLOAT4(in4_0*wt4_0);\n"
" in4_0=z_in_base[3];\n"
" result1 += FLOAT4(in4_0*wt4_2);\n"
" result2 += FLOAT4(in4_0*wt4_1);\n"
" result3 += FLOAT4(in4_0*wt4_0);\n"
" \n"
" in4_0=z_in_base[4];\n"
" result2 += FLOAT4(in4_0*wt4_2);\n"
" result3 += FLOAT4(in4_0*wt4_1);\n"
" in4_0=z_in_base[5];\n"
" result3 += FLOAT4(in4_0*wt4_2);\n"
" }\n"
" }\n"
" *z_out=activate(M4(result0),cst.activation);\n"
" if(valids[0]) { *(z_out+1)=activate(M4(result1),cst.activation);}\n"
" if(valids[1]) { *(z_out+2)=activate(M4(result2),cst.activation);}\n"
" if(valids[2]) { *(z_out+3)=activate(M4(result3),cst.activation);}\n"
"}\n"
// conv_z4: generic window (stride/pad/dilation honoured, window clipped to the
// input) with up to CONV_UNROLL output slices per thread. Bias is added at
// store time.
"kernel void conv_z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c*4 >= cst.output_slice) return;\n"
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
" \n"
" int offset_x=idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
" auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n"
" /* true */ result0 += FLOAT4(in4**x_wt);\n"
" if (valids[0]) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n"
" if (valids[1]) { x_wt += ws; result2 += FLOAT4(in4**x_wt); }\n"
" if (valids[2]) { x_wt += ws; result3 += FLOAT4(in4**x_wt); }\n"
" }\n"
" }\n"
" }\n"
// Per-slice output step is batch*output_size, exactly as in conv_z2 below.
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids[0]) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
" if (valids[1]) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation); }\n"
" if (valids[2]) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation); }\n"
"}\n"
// conv_z2: same as conv_z4 but with two output slices per thread.
"kernel void conv_z2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c*2 >= cst.output_slice) return;\n"
" int2 uz=idx_c*2+int2(0,1);\n"
" bool valids=uz.y<cst.output_slice;\n"
" \n"
" int offset_x=idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
" auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n"
" /* true */ result0 += FLOAT4(in4**x_wt);\n"
" if (valids) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n"
" }\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
"}\n"
;
// Metal source text for the axis reductions mean / sum / min / max / prod.
// Each helper walks `axis_size` elements spaced `inside_size` apart, starting
// at in + gid.x*outside_step + gid.y, and writes one value to
// out[gid.x*inside_size + gid.y]. mean/sum/prod accumulate in type M;
// min/max work directly in the element type T.
// `define_reduce(name)` stamps out two kernels per reduction:
//   reduce_<name>_f  on `M` buffers, accumulating in FLOAT
//   reduce_<name>_s  on `int` buffers
// both guarded by gid.x < outside_size && gid.y < inside_size.
// M and FLOAT (as used by the macro) are not defined in this text; presumably a
// shared prelude supplies them — TODO confirm.
const char* shader_MetalReduction_metal =
R"metal(struct reduce_shape {
 int outside_size;
 int axis_size;
 int inside_size;
 int outside_step;
};
template <typename M,typename T>
static inline void reduce_mean(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {
 auto axis_in=in+gid.x*s.outside_step+gid.y;
 M summer=0;
 for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {
 summer += M(*axis_in);
 }
 out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer/s.axis_size);
}
template <typename M,typename T>
static inline void reduce_sum(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {
 auto axis_in=in+gid.x*s.outside_step+gid.y;
 M summer=0;
 for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {
 summer += M(*axis_in);
 }
 out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);
}
template <typename M,typename T>
static inline void reduce_min(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {
 auto axis_in=in+gid.x*s.outside_step+gid.y;
 T summer=*axis_in; axis_in += s.inside_size;
 for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {
 summer=min(summer,*axis_in);
 }
 out[int(gid.x)*s.inside_size+int(gid.y)]=summer;
}
template <typename M,typename T>
static inline void reduce_max(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {
 auto axis_in=in+gid.x*s.outside_step+gid.y;
 T summer=*axis_in; axis_in += s.inside_size;
 for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {
 summer=max(summer,*axis_in);
 }
 out[int(gid.x)*s.inside_size+int(gid.y)]=summer;
}
template <typename M,typename T>
static inline void reduce_prod(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {
 auto axis_in=in+gid.x*s.outside_step+gid.y;
 M summer=1;
 for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {
 summer *= M(*axis_in);
 }
 out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);
}
)metal"
// The macro body must stay on ONE line of Metal source, so it is assembled from
// ordinary literals that deliberately carry no "\n" until the final fragment.
"#define define_reduce(name) "
"kernel void reduce_##name##_f(const device M *in [[buffer(0)]],"
" device M *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<FLOAT,M>(in,out,s,gid); "
"} "
"kernel void reduce_##name##_s(const device int *in [[buffer(0)]],"
" device int *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<int,int>(in,out,s,gid); "
"}\n"
R"metal(define_reduce(mean);
define_reduce(sum);
define_reduce(min);
define_reduce(max);
define_reduce(prod);
)metal";
// Metal source text for three softmax kernels. All of them address the axis as
// axis_length elements spaced inside_size apart, starting at
// gid.y*axis_length*inside_size + gid.x, and run the usual three passes:
// max, sum of exp(x - max), then exp(x - max)/sum.
//   softmax_plane      : scalar M elements.
//   softmax_on_reorder : M4 elements where the 4 lanes are part of the softmax
//                        axis; lanes whose flat index z*4+lane >= flat_length
//                        are padding and are masked out of max and sum.
//   softmax_off_reorder: M4 elements, each lane an independent softmax.
// M and M4 are not defined in this text; presumably a shared prelude supplies
// them — TODO confirm. FLT_MAX is expected from the Metal standard library,
// which the text already relies on for max/exp/select.
const char* shader_MetalSoftmax_metal =
"struct softmax_shape {\n"
" int inside_size;\n"
" int axis_length;\n"
" int outside_size;\n"
" int flat_length;\n"
"};\n"
"static inline float softmax_max4(float4 V) {\n"
" return max(max(V[0],V[1]),max(V[2],V[3]));\n"
"}\n"
"static inline float softmax_sum4(float4 V) {\n"
" return V[0]+V[1]+V[2]+V[3];\n"
"}\n"
// Sum-pass mask: padded lanes contribute 0.
"static inline float4 softmax_filter(float4 V,int z,int limit) {\n"
" return select(0,V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
// Max-pass mask: padded lanes must be the identity for max(), i.e. the most
// negative finite float. Filling with 0 would clamp the running max to >= 0
// and let exp(x-max) underflow to 0 for strongly negative inputs.
"static inline float4 softmax_filter_max(float4 V,int z,int limit) {\n"
" return select(float4(-FLT_MAX),V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
"kernel void softmax_plane(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" \n"
" // get max\n"
" auto max1=axis_in[0];\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max1=max(max1,axis_in[i*s.inside_size]);\n"
" }\n"
" \n"
" // get sum\n"
// Same expression as the output pass so numerator and denominator are computed
// with identical precision.
" float sum1=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum1 += exp(float(axis_in[i*s.inside_size]-max1));\n"
" }\n"
" \n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M(exp(float(axis_in[i*s.inside_size]-max1))/sum1);\n"
" }\n"
"}\n"
"kernel void softmax_on_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max\n"
" auto max4=softmax_filter_max(float4(axis_in[0]),0,s.flat_length);\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,softmax_filter_max(float4(axis_in[i*s.inside_size]),i,s.flat_length));\n"
" }\n"
" float max1=softmax_max4(max4);\n"
" \n"
" // get sum\n"
// Widen to float4 before subtracting, exactly as the output pass does.
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += softmax_filter(exp(float4(axis_in[i*s.inside_size])-max1),i,s.flat_length);\n"
" }\n"
" float sum1=softmax_sum4(sum4);\n"
" \n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size])-max1)/sum1);\n"
" }\n"
"}\n"
"kernel void softmax_off_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max\n"
" auto max4=axis_in[0];\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,axis_in[i*s.inside_size]);\n"
" }\n"
" // get sum\n"
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += exp(float4(axis_in[i*s.inside_size]-max4));\n"
" }\n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size]-max4))/sum4);\n"
" }\n"
"}\n"
;
const char* shader_MetalLayerNorm_metal =
"struct layernorm_constants {\n"
" int inside;\n"
" int outside;\n"
" float eps;\n"
" int has_gamma_beta;\n"
"};\n"
"kernel void layernorm_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float *gamma [[buffer(3)]],\n"
" const device float *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside;\n"
" auto out_data=out+gid.y*cst.inside;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" sum += in_data[i];\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" float dis=(in_data[i]-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float norm=var*((float)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside/4;\n"
" auto out_data=out+gid.y*cst.inside/4;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" sum += in_data[i].x;\n"
" sum += in_data[i].y;\n"
" sum += in_data[i].z;\n"
" sum += in_data[i].w;\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" float dis=(in_data[i].x-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].y-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].z-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].w-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M4)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x1_rms(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float *gamma [[buffer(3)]],\n"
" const device float *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside;\n"
" auto out_data=out+gid.y*cst.inside;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" float dis=in_data[i];\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float norm=var*((float)in_data[gid.x]);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x4_rms(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside/4;\n"
" auto out_data=out+gid.y*cst.inside/4;\n"
" float square_sum=0.0f;\n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" float dis=in_data[i].x;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].y;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].z;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].w;\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[gid.x]);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M4)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_m1x4_rms(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint gid [[threadgroup_position_in_grid]],\n"
" uint tiisg[[thread_index_in_simdgroup]],\n"
" uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
" int total_idx=(gid*4+sgitg);\n"
" int in_idx=total_idx % (cst.inside/4);\n"
" int out_idx=total_idx/(cst.inside/4);\n"
" auto in_data=in+out_idx*cst.inside/4;\n"
" auto out_data=out+out_idx*cst.inside/4;\n"
" float square_sum=0.0f;\n"
" for(int i=tiisg; i<cst.inside/4; i+=SIMD_GROUP_WIDTH) {\n"
" M4 data=in_data[i];\n"
" float dis=data.x;\n"
" square_sum += dis*dis;\n"
" dis=data.y;\n"
" square_sum += dis*dis;\n"
" dis=data.z;\n"
" square_sum += dis*dis;\n"
" dis=data.w;\n"
" square_sum += dis*dis;\n"
" }\n"
" square_sum=simd_sum(square_sum);\n"
" \n"
" if(tiisg == 0) {\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[in_idx]);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[in_idx]=(M4)(norm*gamma[in_idx]+beta[in_idx]);\n"
" } else {\n"
" out_data[in_idx]=(M4)(norm);\n"
" }\n"
" }\n"
"}\n"
;
// Metal shader source for the Winograd convolution transforms, embedded as a
// single C string built from adjacent literals. It contains:
//   - winograd_constants: parameter block shared by all kernels below,
//   - get_input: bounds-checked loader that yields 0 outside the input plane,
//   - winograd_transform_source2_5_1 / _source2_3_1: input-tile transforms that
//     read a 6x6 / 4x4 tile and emit 36 / 16 transformed values,
//   - set_output: writer that applies activate() before storing,
//   - winograd_transform_dest2_5_1 / _dest2_3_1: output transforms that read
//     36 / 16 values, add the bias and write up to a 2x2 output patch.
// The text uses M4, UP_DIV, activate and conv_activation_type, none of which
// are defined inside this string.
// NOTE(review): presumably those names come from a common prelude that the
// host prepends before compiling - confirm against the Metal backend code.
const char* shader_MetalConvolutionWinograd_metal = 
"struct winograd_constants {\n"
" int4 input_shape;\n"
" int4 output_shape;\n"
" int pad_x;\n"
" int pad_y;\n"
" int unit_width;\n"
" int unit_height;\n"
" int unit;\n"
" conv_activation_type activation;\n"
"};\n"
// get_input: returns input[x + y*input_shape.x] when 0 <= x < input_shape.x and
// 0 <= y < input_shape.y, otherwise 0 (out-of-range taps contribute zero).
"static inline M4 get_input(const device M4 *input,int x,int y,constant winograd_constants &cst) {\n"
" return x<cst.input_shape.x && y<cst.input_shape.y && x >= 0 && y >= 0 ? input[x+y*cst.input_shape.x] : 0;\n"
"}\n"
// winograd_transform_source2_5_1: one thread per (unit x, unit y, plane z).
// Threads with pos.x >= unit_width or pos.y >= unit_height do nothing; pos.z is
// not range-checked inside the kernel.
"kernel void winograd_transform_source2_5_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
// Top-left corner of the tile in input coordinates (may be negative because of
// padding; get_input handles that).
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=pos.y*cst.unit-cst.pad_y;\n"
" auto z_in=in+pos.z*cst.input_shape.x*cst.input_shape.y;\n"
// Load the 6x6 tile; Sxy is the element at column ix+x, row iy+y.
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S40=get_input(z_in,ix+4,iy+0,cst);\n"
" auto S50=get_input(z_in,ix+5,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S41=get_input(z_in,ix+4,iy+1,cst);\n"
" auto S51=get_input(z_in,ix+5,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S42=get_input(z_in,ix+4,iy+2,cst);\n"
" auto S52=get_input(z_in,ix+5,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
" auto S43=get_input(z_in,ix+4,iy+3,cst);\n"
" auto S53=get_input(z_in,ix+5,iy+3,cst);\n"
" auto S04=get_input(z_in,ix+0,iy+4,cst);\n"
" auto S14=get_input(z_in,ix+1,iy+4,cst);\n"
" auto S24=get_input(z_in,ix+2,iy+4,cst);\n"
" auto S34=get_input(z_in,ix+3,iy+4,cst);\n"
" auto S44=get_input(z_in,ix+4,iy+4,cst);\n"
" auto S54=get_input(z_in,ix+5,iy+4,cst);\n"
" auto S05=get_input(z_in,ix+0,iy+5,cst);\n"
" auto S15=get_input(z_in,ix+1,iy+5,cst);\n"
" auto S25=get_input(z_in,ix+2,iy+5,cst);\n"
" auto S35=get_input(z_in,ix+3,iy+5,cst);\n"
" auto S45=get_input(z_in,ix+4,iy+5,cst);\n"
" auto S55=get_input(z_in,ix+5,iy+5,cst);\n"
// First pass: combine along the row index (second digit) for every column.
" auto m00=+S00-1.25*S02+0.25*S04;\n"
" auto m10=+S10-1.25*S12+0.25*S14;\n"
" auto m20=+S20-1.25*S22+0.25*S24;\n"
" auto m30=+S30-1.25*S32+0.25*S34;\n"
" auto m40=+S40-1.25*S42+0.25*S44;\n"
" auto m50=+S50-1.25*S52+0.25*S54;\n"
" auto m01=+0.666667*S01+0.666667*S02-0.166667*S03-0.166667*S04;\n"
" auto m11=+0.666667*S11+0.666667*S12-0.166667*S13-0.166667*S14;\n"
" auto m21=+0.666667*S21+0.666667*S22-0.166667*S23-0.166667*S24;\n"
" auto m31=+0.666667*S31+0.666667*S32-0.166667*S33-0.166667*S34;\n"
" auto m41=+0.666667*S41+0.666667*S42-0.166667*S43-0.166667*S44;\n"
" auto m51=+0.666667*S51+0.666667*S52-0.166667*S53-0.166667*S54;\n"
" auto m02=-0.666667*S01+0.666667*S02+0.166667*S03-0.166667*S04;\n"
" auto m12=-0.666667*S11+0.666667*S12+0.166667*S13-0.166667*S14;\n"
" auto m22=-0.666667*S21+0.666667*S22+0.166667*S23-0.166667*S24;\n"
" auto m32=-0.666667*S31+0.666667*S32+0.166667*S33-0.166667*S34;\n"
" auto m42=-0.666667*S41+0.666667*S42+0.166667*S43-0.166667*S44;\n"
" auto m52=-0.666667*S51+0.666667*S52+0.166667*S53-0.166667*S54;\n"
" auto m03=-0.0833333*S01-0.0416667*S02+0.0833333*S03+0.0416667*S04;\n"
" auto m13=-0.0833333*S11-0.0416667*S12+0.0833333*S13+0.0416667*S14;\n"
" auto m23=-0.0833333*S21-0.0416667*S22+0.0833333*S23+0.0416667*S24;\n"
" auto m33=-0.0833333*S31-0.0416667*S32+0.0833333*S33+0.0416667*S34;\n"
" auto m43=-0.0833333*S41-0.0416667*S42+0.0833333*S43+0.0416667*S44;\n"
" auto m53=-0.0833333*S51-0.0416667*S52+0.0833333*S53+0.0416667*S54;\n"
" auto m04=+0.0833333*S01-0.0416667*S02-0.0833333*S03+0.0416667*S04;\n"
" auto m14=+0.0833333*S11-0.0416667*S12-0.0833333*S13+0.0416667*S14;\n"
" auto m24=+0.0833333*S21-0.0416667*S22-0.0833333*S23+0.0416667*S24;\n"
" auto m34=+0.0833333*S31-0.0416667*S32-0.0833333*S33+0.0416667*S34;\n"
" auto m44=+0.0833333*S41-0.0416667*S42-0.0833333*S43+0.0416667*S44;\n"
" auto m54=+0.0833333*S51-0.0416667*S52-0.0833333*S53+0.0416667*S54;\n"
" auto m05=+4.0*S01-5.0*S03+S05;\n"
" auto m15=+4.0*S11-5.0*S13+S15;\n"
" auto m25=+4.0*S21-5.0*S23+S25;\n"
" auto m35=+4.0*S31-5.0*S33+S35;\n"
" auto m45=+4.0*S41-5.0*S43+S45;\n"
" auto m55=+4.0*S51-5.0*S53+S55;\n"
// Destination addressing: the linear unit index (unit_width*pos.y + pos.x) is
// split into a group of four (dst_y) and a lane inside the group (index % 4);
// the plane pos.z selects the block of four lanes. Successive transform
// outputs are written `stride` M4 elements apart.
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
// Second pass: combine along the column index (first digit) with the same six
// coefficient rows and store the 36 results in order.
" *xy_out=+m00-1.25*m20+0.25*m40;\n"
" xy_out += stride; *xy_out=+0.666667*m10+0.666667*m20-0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.666667*m10+0.666667*m20+0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.0833333*m10-0.0416667*m20+0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+0.0833333*m10-0.0416667*m20-0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+4.0*m10-5.0*m30+m50;\n"
" xy_out += stride; *xy_out=+m01-1.25*m21+0.25*m41;\n"
" xy_out += stride; *xy_out=+0.666667*m11+0.666667*m21-0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.666667*m11+0.666667*m21+0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.0833333*m11-0.0416667*m21+0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+0.0833333*m11-0.0416667*m21-0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+4.0*m11-5.0*m31+m51;\n"
" xy_out += stride; *xy_out=+m02-1.25*m22+0.25*m42;\n"
" xy_out += stride; *xy_out=+0.666667*m12+0.666667*m22-0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.666667*m12+0.666667*m22+0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.0833333*m12-0.0416667*m22+0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+0.0833333*m12-0.0416667*m22-0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+4.0*m12-5.0*m32+m52;\n"
" xy_out += stride; *xy_out=+m03-1.25*m23+0.25*m43;\n"
" xy_out += stride; *xy_out=+0.666667*m13+0.666667*m23-0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.666667*m13+0.666667*m23+0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.0833333*m13-0.0416667*m23+0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+0.0833333*m13-0.0416667*m23-0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+4.0*m13-5.0*m33+m53;\n"
" xy_out += stride; *xy_out=+m04-1.25*m24+0.25*m44;\n"
" xy_out += stride; *xy_out=+0.666667*m14+0.666667*m24-0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.666667*m14+0.666667*m24+0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.0833333*m14-0.0416667*m24+0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+0.0833333*m14-0.0416667*m24-0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+4.0*m14-5.0*m34+m54;\n"
" xy_out += stride; *xy_out=+m05-1.25*m25+0.25*m45;\n"
" xy_out += stride; *xy_out=+0.666667*m15+0.666667*m25-0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.666667*m15+0.666667*m25+0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.0833333*m15-0.0416667*m25+0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+0.0833333*m15-0.0416667*m25-0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+4.0*m15-5.0*m35+m55;\n"
" }\n"
"}\n"
// winograd_transform_source2_3_1: same structure as the kernel above, but with
// a 4x4 input tile and 16 transformed outputs.
"kernel void winograd_transform_source2_3_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=pos.y*cst.unit-cst.pad_y;\n"
" auto z_in=in+pos.z*cst.input_shape.x*cst.input_shape.y;\n"
// Load the 4x4 tile; Sxy is the element at column ix+x, row iy+y.
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
// First pass along the row index.
" auto m00=+S00-S02;\n"
" auto m10=+S10-S12;\n"
" auto m20=+S20-S22;\n"
" auto m30=+S30-S32;\n"
" auto m01=+0.5*S01+0.5*S02;\n"
" auto m11=+0.5*S11+0.5*S12;\n"
" auto m21=+0.5*S21+0.5*S22;\n"
" auto m31=+0.5*S31+0.5*S32;\n"
" auto m02=-0.5*S01+0.5*S02;\n"
" auto m12=-0.5*S11+0.5*S12;\n"
" auto m22=-0.5*S21+0.5*S22;\n"
" auto m32=-0.5*S31+0.5*S32;\n"
" auto m03=-S01+S03;\n"
" auto m13=-S11+S13;\n"
" auto m23=-S21+S23;\n"
" auto m33=-S31+S33;\n"
// Same destination addressing scheme as winograd_transform_source2_5_1.
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
// Second pass along the column index; 16 stores, `stride` elements apart.
" *xy_out=+m00-m20;\n"
" xy_out += stride; *xy_out=+0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-m10+m30;\n"
" xy_out += stride; *xy_out=+m01-m21;\n"
" xy_out += stride; *xy_out=+0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-m11+m31;\n"
" xy_out += stride; *xy_out=+m02-m22;\n"
" xy_out += stride; *xy_out= +0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-m12+m32;\n"
" xy_out += stride; *xy_out=+m03-m23;\n"
" xy_out += stride; *xy_out=+0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-m13+m33;\n"
" }\n"
"}\n"
// set_output: stores activate(V, cst.activation) at (x, y) of an output plane
// whose row length is output_shape.x. No bounds check is done here; callers
// guard the coordinates.
"static inline void set_output(constant winograd_constants &cst,device M4 *output,int x,int y,M4 V) {\n"
" output[y*cst.output_shape.x+x]=activate(V,cst.activation);\n"
"}\n"
// winograd_transform_dest2_5_1: one thread per (unit x, unit y, plane z).
// Reads 36 values spaced dst_w elements apart, reduces them to a 2x2 patch,
// adds biasTerms[pos.z] and writes the patch at (pos.x*unit, pos.y*unit).
"kernel void winograd_transform_dest2_5_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
// Source addressing: the linear unit index is split into a group of four
// (dst_x) and a lane (index % 4) that, together with pos.z, selects the row.
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
" int dst_y_stride=dst_w*36;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S40=*xy_in; xy_in += dst_w;\n"
" auto S50=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S41=*xy_in; xy_in += dst_w;\n"
" auto S51=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S42=*xy_in; xy_in += dst_w;\n"
" auto S52=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in; xy_in += dst_w;\n"
" auto S43=*xy_in; xy_in += dst_w;\n"
" auto S53=*xy_in; xy_in += dst_w;\n"
" auto S04=*xy_in; xy_in += dst_w;\n"
" auto S14=*xy_in; xy_in += dst_w;\n"
" auto S24=*xy_in; xy_in += dst_w;\n"
" auto S34=*xy_in; xy_in += dst_w;\n"
" auto S44=*xy_in; xy_in += dst_w;\n"
" auto S54=*xy_in; xy_in += dst_w;\n"
" auto S05=*xy_in; xy_in += dst_w;\n"
" auto S15=*xy_in; xy_in += dst_w;\n"
" auto S25=*xy_in; xy_in += dst_w;\n"
" auto S35=*xy_in; xy_in += dst_w;\n"
" auto S45=*xy_in; xy_in += dst_w;\n"
" auto S55=*xy_in;\n"
// Reduce along the second index: six inputs -> two outputs per column.
" auto m00=+S00+S01+S02+S03+S04;\n"
" auto m10=+S10+S11+S12+S13+S14;\n"
" auto m20=+S20+S21+S22+S23+S24;\n"
" auto m30=+S30+S31+S32+S33+S34;\n"
" auto m40=+S40+S41+S42+S43+S44;\n"
" auto m50=+S50+S51+S52+S53+S54;\n"
" auto m01=+S01-S02+2.0*S03-2.0*S04+S05;\n"
" auto m11=+S11-S12+2.0*S13-2.0*S14+S15;\n"
" auto m21=+S21-S22+2.0*S23-2.0*S24+S25;\n"
" auto m31=+S31-S32+2.0*S33-2.0*S34+S35;\n"
" auto m41=+S41-S42+2.0*S43-2.0*S44+S45;\n"
" auto m51=+S51-S52+2.0*S53-2.0*S54+S55;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=pos.y*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+pos.z*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
// The (ox, oy) element is written unconditionally; the other three elements of
// the 2x2 patch are skipped when they fall outside output_shape.
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20+m30+m40);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+2.0*m30-2.0*m40+m50);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21+m31+m41);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+2.0*m31-2.0*m41+m51);\n"
" }\n"
" }\n"
"}\n"
// winograd_transform_dest2_3_1: same structure as the kernel above, but it
// consumes 16 values per unit (row stride dst_w*16).
"kernel void winograd_transform_dest2_3_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
" int dst_y_stride=dst_w*16;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in;\n"
// Reduce along the second index: four inputs -> two outputs per column.
" auto m00=+S00+S01+S02;\n"
" auto m10=+S10+S11+S12;\n"
" auto m20=+S20+S21+S22;\n"
" auto m30=+S30+S31+S32;\n"
" auto m01=+S01-S02+S03;\n"
" auto m11=+S11-S12+S13;\n"
" auto m21=+S21-S22+S23;\n"
" auto m31=+S31-S32+S33;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=pos.y*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+pos.z*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
// As above: (ox, oy) is always written, the remaining patch elements only when
// they lie inside output_shape.
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+m30);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+m31);\n"
" }\n"
" }\n"
"}\n"
;
// Metal shader source for a plain strided matrix multiplication, embedded as a
// C string. One thread computes one output element (gid.x, gid.y); threads with
// gid.x >= mat_size.x or gid.y >= mat_size.y do nothing.
//   matmul      : out[y*mat_size.x + x] = sum_{i < mat_size.z} in0[...] * in1[...]
//   matmul_bias : same sum, plus biasValue[x].
// Operand addressing: in0 starts at y*in_stride.x and advances by in_stride.y
// per step; in1 starts at x*in_stride.z and advances by in_stride.w per step.
// The products are accumulated in FLOAT and converted to M on store.
// M and FLOAT are not defined inside this string.
// NOTE(review): presumably they are storage/compute scalar types supplied by a
// common prelude - confirm against the host code.
const char* shader_MetalMatMul_metal = 
"struct matmul_shape {\n"
" int4 mat_size;\n"
" int4 in_stride;\n"
"};\n"
"kernel void matmul(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant matmul_shape &s [[buffer(3)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
// Both operand pointers are stepped in the loop increment expression.
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V);\n"
" }\n"
"}\n"
// matmul_bias: identical to matmul except for the extra bias buffer at index 2
// (which shifts `out` to buffer 3 and the shape constants to buffer 4).
"kernel void matmul_bias(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" const device M *biasValue [[buffer(2)]],\n"
" device M *out [[buffer(3)]],\n"
" constant matmul_shape &s [[buffer(4)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
// The bias is added after the accumulator has been converted to M.
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V)+biasValue[(int)(gid.x)];\n"
" }\n"
"}\n"
;
// Metal shader source for the per-channel scale operator, embedded as a C
// string. Kernel scale_ca computes, for each thread (gid.x, gid.y):
//   out[gid.y*size + gid.x] = in[gid.y*size + gid.x] * scales[z] + biasTerms[z]
// with z = gid.y / batch. Threads with gid.x >= size or gid.y >= steps*batch
// return without writing. scales/biasTerms are float4 and are converted to M4
// before the multiply-add.
// M4 is not defined inside this string.
// NOTE(review): presumably M4 is the 4-wide storage vector type from a common
// prelude - confirm against the host code.
const char* shader_MetalScale_metal = 
"struct scale_shape {\n"
" int size;\n"
" int steps;\n"
" int batch;\n"
"};\n"
"kernel void scale_ca(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant scale_shape &s [[buffer(2)]],\n"
" const device float4 *scales [[buffer(3)]],\n"
" const device float4 *biasTerms [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n"
// gid.y enumerates steps*batch rows; dividing by batch selects the scale/bias
// entry shared by all rows of the same step.
" int z=gid.y/s.batch;\n"
" out[int(gid.y)*s.size+int(gid.x)] =\n"
" in [int(gid.y)*s.size+int(gid.x)]*M4(scales[z])+M4(biasTerms[z]);\n"
"}\n"
;
// Metal shader source for transposed convolution (deconvolution), embedded as
// a C string. It contains the deconv_constants parameter block and two
// kernels that each compute one output element per thread:
//   - deconv           : full deconvolution, M4x4 weights, sums over all
//                        input slices;
//   - deconv_depthwise : depthwise variant, M4 weights, one slice per thread.
// In both kernels gid.z enumerates batch*output_slice planes, the channel
// slice is gid.z / batch and the batch index is gid.z % batch.
// The depthwise kernel indexes its weights by the channel slice
// (gid.z / batch), the same index it uses for the bias; indexing by the raw
// gid.z would step past the per-slice weights as soon as batch > 1.
// M4, M4x4, FLOAT4, UP_DIV, ROUND_UP, activate and conv_activation_type are not
// defined inside this string.
// NOTE(review): presumably they come from a common prelude prepended by the
// host before compilation - confirm against the Metal backend code.
const char* shader_MetalDeconvolution_metal = 
"struct deconv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" \n"
" int delta_ky;\n"
" int delta_kx;\n"
" int delta_iy;\n"
" int delta_ix;\n"
" int has_bias;\n"
" int batch;\n"
" conv_activation_type activation;\n"
"};\n"
"kernel void deconv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
// Split the plane index into batch (b) and output channel slice (o).
" int b=gid.z % cst.batch;\n"
" int o=gid.z/cst.batch;\n"
// The accumulator always starts from the bias; has_bias is not consulted here.
" FLOAT4 result=FLOAT4(biasTerms[o]);\n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
// Only output positions whose smallest kernel tap lands exactly on a strided
// input sample accumulate anything; all others keep the bias value.
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
" auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n"
" auto b_in=in+b*cst.input_size;\n"
// Kernel taps run from max down to min in steps of delta_k* while the input
// coordinate runs upward in steps of delta_i*; the step sizes are host-supplied.
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n"
" auto in4=b_in[z*cst.input_size*cst.batch+iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
"kernel void deconv_depthwise(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z/cst.batch)]);\n"
" \n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
// Weights are selected per channel slice (gid.z / batch), matching the bias
// lookup above; the input plane is selected by the full plane index gid.z.
" auto z_wt=wt+(int)(gid.z/cst.batch)*cst.kernel_size;\n"
" auto z_in=in+(int)gid.z*cst.input_size;\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=z_wt[ky*cst.kernel_x+kx];\n"
" auto in4=z_in[iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
;
// Metal shader source for 2D max and average pooling, embedded as a C string.
// Each thread produces one output element (gid.x, gid.y) of plane gid.z and
// returns early when any component of gid is outside
// (output_width, output_height, slice).
//   - pooling_max: maximum over the kernel window; window coordinates are
//     clamped into the input, so taps that fall outside re-read the nearest
//     edge element instead of a padding value.
//   - pooling_avg: sum over the part of the window that lies inside the
//     input, divided by the number of in-range taps.
// M4 and FLOAT4 are not defined inside this string.
// NOTE(review): presumably they come from a common prelude - confirm against
// the host code.
const char* shader_MetalPooling_metal = 
"struct pooling_sizes {\n"
" int input_width;\n"
" int input_height;\n"
" int output_width;\n"
" int output_height;\n"
" int slice;\n"
" int kernel_width;\n"
" int kernel_height;\n"
" int stride_width;\n"
" int stride_height;\n"
" int pad_width;\n"
" int pad_height;\n"
"};\n"
"kernel void pooling_max(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
// Window origin in input coordinates (can be negative because of padding) and
// its exclusive end.
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
" int x_max=s.input_width-1;\n"
" int y_max=s.input_height-1;\n"
" int ex=off_x+s.kernel_width;\n"
" int ey=off_y+s.kernel_height;\n"
" \n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
// Seed the running maximum with the (clamped) first tap of the window.
" auto result=M4(z_in[clamp(off_y,0,y_max)*s.input_width+clamp(off_x,0,x_max)]);\n"
" for (int y=off_y; y<ey; y++) {\n"
" auto y_in=z_in+clamp(y,0,y_max)*s.input_width;\n"
" for (int x=off_x; x<ex; x++) {\n"
" result=max(result,y_in[clamp(x,0,x_max)]);\n"
" }\n"
" }\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=result;\n"
"}\n"
"kernel void pooling_avg(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
// Clip the window to the input: sx/sy = max(off, 0), ex/ey = min(off + kernel,
// input extent).
" int sx=off_x+max(0,-off_x);\n"
" int sy=off_y+max(0,-off_y);\n"
" int ex=off_x+min(s.kernel_width,s.input_width-off_x);\n"
" int ey=off_y+min(s.kernel_height,s.input_height-off_y);\n"
" \n"
" FLOAT4 result=0;\n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
" for (int y=sy; y<ey; y++) {\n"
" for (int x=sx; x<ex; x++) {\n"
" result += FLOAT4(z_in[y*s.input_width+x]);\n"
" }\n"
" }\n"
// Divide by the clipped window area; a non-positive count leaves the (zero)
// sum multiplied by 1 instead of dividing by zero.
" int count=(ey-sy)*(ex-sx);\n"
" FLOAT4 div=count>0 ? 1.f/count : 1;\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=M4(result*div);\n"
"}\n"
;
// Metal shader source for ROI max pooling, embedded as a C string.
// Each thread produces one output bin (gid.x, gid.y) of plane gid.z; threads
// with gid.x >= output_width or gid.y >= output_height return early (gid.z is
// not range-checked inside the kernel).
// Each ROI is described by five consecutive M values: the first is read as the
// input batch index, the next four are multiplied by spatial_scale and rounded
// to give x1, y1, x2, y2.
// The kernel writes the element-wise maximum over the bin, or 0 when the bin
// is empty after clamping to the input.
// M and M4 are not defined inside this string.
// NOTE(review): presumably they come from a common prelude - confirm against
// the host code.
const char* shader_MetalROIPooling_metal = 
"struct ROI_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_batch;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int batch;\n"
" float spatial_scale;\n"
"};\n"
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n"
" const device M *roi [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant ROI_shape &s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n"
" \n"
// gid.z is split into the ROI index (ob) and the channel slice (iz).
" int ob=gid.z % s.batch;\n"
" int iz=gid.z/s.batch;\n"
" \n"
" auto b_roi=roi+ob*5;\n"
" int ib=int(b_roi[0]);\n"
" int x1=round(float(b_roi[1])*s.spatial_scale);\n"
" int y1=round(float(b_roi[2])*s.spatial_scale);\n"
" int x2=round(float(b_roi[3])*s.spatial_scale);\n"
" int y2=round(float(b_roi[4])*s.spatial_scale);\n"
" \n"
// ROI extent is at least 1x1; each output bin covers roi/output input pixels.
" int roi_w=max(x2-x1+1,1);\n"
" int roi_h=max(y2-y1+1,1);\n"
" float bin_size_w=(float)roi_w/(float)s.output_width;\n"
" float bin_size_h=(float)roi_h/(float)s.output_height;\n"
" \n"
// Bin bounds: floor for the start, ceil for the exclusive end, both clamped to
// [0, input extent].
" int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n"
" int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n"
" int h_start=clamp(y1+(int)floor(gid.y*bin_size_h) ,0,s.input_height);\n"
" int h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n"
" \n"
" int is_empty=(h_end <= h_start) || (w_end <= w_start);\n"
" auto z_in=in+(ib+iz*s.input_batch)*s.input_size;\n"
// For an empty bin the loops below do not contribute and 0 is written.
" auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n"
" for (int y=h_start; y<h_end; y++) {\n"
" auto y_in=z_in+y*s.input_width;\n"
" for (int x=w_start; x<w_end; x++) {\n"
" max4=max(max4,y_in[x]);\n"
" }\n"
" }\n"
" out[int(gid.z)*s.output_size+int(gid.y)*s.output_width+int(gid.x)]=max4;\n"
"}\n"
;
// Metal source for the 1x1 convolution kernels, stored as one concatenated C string literal.
// M4, M4x4, FLOAT4, FLOAT4x4, MNN::char4x4, MNN::uchar4x2 and SIMD_GROUP_WIDTH are not declared
// in this literal; declarations with those names appear in shader_MetalDefine_metal.
// conv_activation_type and activate() are declared elsewhere (not visible in this section).
//
// Changes relative to the previous text of this shader:
//  * conv1x1_g1z4_m1w4 returns early when its output slice index is out of range.
//  * conv1x1_w2c2 / conv1x1_w4c2 no longer reject valid (batch, slice) pairs in their first
//    guard, and no longer read bias/weights of a second slice that does not exist.
const char* shader_MetalConvolution1x1_metal =
// Number of consecutive output positions handled per thread by the g1z4* / g1z8 kernels.
"#define CONV_UNROLL (4)\n"
"#define CONV_UNROLL_L (8)\n"
// Constants shared by every kernel below. The quantized kernels use block_size as the number
// of dequantization blocks that the input_slice range is divided into.
"struct conv1x1_constants {\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int output_channel;\n"
" int batch;\n"
" int block_size;\n"
" conv_activation_type activation;\n"
"};\n"
// conv1x1_g1z4: gid.x selects a run of CONV_UNROLL output positions, gid.y the output slice,
// gid.z the batch. Accumulators start from the bias and sum in*w over all input slices.
"kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" \n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
// All four inputs are loaded even when computeSize is below 4; only the stores are limited
// (NOTE(review): presumably the input buffer is padded -- confirm with the host code).
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=*xy_in0;\n"
" auto in41=*(xy_in0+1);\n"
" auto in42=*(xy_in0+2);\n"
" auto in43=*(xy_in0+3);\n"
" auto w=xy_wt[z];\n"
" \n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" \n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z4_w8: same thread layout as conv1x1_g1z4, but weights are stored as char4x4 and
// converted per column i as w*scale[i]+dequant_bias[i]. For output slice uz and block bi the
// scale is dequantScale[2*(uz*block_size+bi)] and the bias is the element after it.
"kernel void conv1x1_g1z4_w8(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device MNN::char4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device float4 *dequantScale [[buffer(5)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
// block = number of input slices per dequantization block (rounded up).
" int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
" for (int bi=0; bi<cst.block_size; ++bi) {\n"
" FLOAT4 bs0=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
" FLOAT4 bs1=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
" FLOAT4 scale=bs0;\n"
" FLOAT4 dequant_bias=bs1;\n"
" int zmin=bi*block;\n"
" int zmax=min(zmin+block,cst.input_slice);\n"
" for (int z=zmin; z<zmax; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=(FLOAT4)*(xy_in0+1);\n"
" auto in42=(FLOAT4)*(xy_in0+2);\n"
" auto in43=(FLOAT4)*(xy_in0+3);\n"
" auto w=xy_wt[z];\n"
" FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" w_dequant[i]=w_fp32[i]*scale[i]+dequant_bias[i];\n"
" }\n"
" result0 += FLOAT4(in40*w_dequant);\n"
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" }\n"
" /* true */ \n"
" xy_out[0]=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z4_w4: as conv1x1_g1z4_w8, but each uchar packs two 4-bit weights: the high
// nibble comes first, then the low nibble, and 8 is subtracted from each before scaling.
"kernel void conv1x1_g1z4_w4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device float4 *dequantScale [[buffer(5)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
" for (int bi=0; bi<cst.block_size; ++bi) {\n"
" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
" int zmin=bi*block;\n"
" int zmax=min(zmin+block,cst.input_slice);\n"
" for (int z=zmin; z<zmax; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=(FLOAT4)*(xy_in0+1);\n"
" auto in42=(FLOAT4)*(xy_in0+2);\n"
" auto in43=(FLOAT4)*(xy_in0+3);\n"
" MNN::uchar4x2 w_int4=xy_wt[z];\n"
" // MNN::char4x4 w_int8(char4(0));\n"
" /* weight int4->float */\n"
" //FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" // M4 w4=M4(w_fp32[i]);\n"
" FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
" FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
" w_dequant[i]=res;\n"
" }\n"
" result0 += FLOAT4(in40*w_dequant);\n"
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" }\n"
" \n"
" /* true */ \n"
" xy_out[0]=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" // MNN::uchar4x2 w=xy_wt[0];\n"
" // xy_out[0]=M4(w[0][0],w[0][1],w[0][0],w[0][1]);\n"
" // xy_out[0]=M4((float)(w[0][0]>>4)-8,(float)(w[0][0] >> 4),(float)(w[0][0] & 15)-8,(float)(w[0][0] & 15));\n"
" \n"
" \n"
" /* true */ \n"
" //xy_out[0]=activate(M4(result0),cst.activation);\n"
" //if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" //if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" //if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z4_m1w4: 4-bit-weight variant indexed by threadgroup position. Each simdgroup
// produces one output element: lanes stride over the input slices by SIMD_GROUP_WIDTH, the
// partial sums are combined with simd_sum, and lane 0 adds the bias and stores the result.
"kernel void conv1x1_g1z4_m1w4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device float4 *dequantScale [[buffer(5)]],\n"
" uint3 gid[[threadgroup_position_in_grid]],\n"
" uint tiisg[[thread_index_in_simdgroup]],\n"
" uint sgitg[[simdgroup_index_in_threadgroup]]) {\n"
" int uz=gid.x*2+sgitg;\n"
// Bounds guard (added): uz depends only on the threadgroup position and the simdgroup index,
// so it is uniform within a simdgroup and the whole simdgroup leaves together before
// simd_sum is reached. Without it, a dispatch that rounds output_slice up to a multiple of 2
// indexes one slice past the weights, bias, scales and output.
" if (uz >= cst.output_slice) return;\n"
" int rx=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=FLOAT4(0);\n"
" int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
" for (int bi=0; bi<cst.block_size; bi++) {\n"
" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0]);\n"
" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1]);\n"
" int zmin=bi*block;\n"
" int zmax=min(zmin+block,cst.input_slice);\n"
" for (int z=zmin+tiisg; z<zmax; z+=SIMD_GROUP_WIDTH) {\n"
" auto in40=(FLOAT4)*(xy_in0+z*cst.input_size*cst.batch);\n"
" MNN::uchar4x2 w_int4=xy_wt[z];\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
" FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
" w_dequant[i]=res;\n"
" }\n"
" result0 += FLOAT4(in40*w_dequant);\n"
" \n"
"// FLOAT4x4 w_dequant;\n"
"// for (int i=0; i<4; ++i) {\n"
"// FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
"// FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
"// w_dequant[i]=w4;\n"
"// }\n"
"//\n"
"// FLOAT4 temp=FLOAT4(in40*w_dequant);\n"
"// result0 += temp*scale+(in40.x+in40.y+in40.z+in40.w)*dequant_bias;\n"
" }\n"
" }\n"
" FLOAT4 res;\n"
" res.x=simd_sum(result0.x);\n"
" res.y=simd_sum(result0.y);\n"
" res.z=simd_sum(result0.z);\n"
" res.w=simd_sum(result0.w);\n"
" /* true */\n"
" if (tiisg == 0) {\n"
" xy_out[0]=activate(M4(res+biasValue),cst.activation);\n"
" }\n"
"}\n"
// conv1x1_g1z8: same scheme as conv1x1_g1z4 with CONV_UNROLL_L (8) output positions per thread.
"kernel void conv1x1_g1z8(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL_L >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL_L;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.batch*cst.output_size+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL_L);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in42=xy_in0[2];\n"
" auto in43=xy_in0[3];\n"
" auto in44=xy_in0[4];\n"
" auto in45=xy_in0[5];\n"
" auto in46=xy_in0[6];\n"
" auto in47=xy_in0[7];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" if (computeSize>4) {xy_out[4]=activate(M4(result4),cst.activation); }\n"
" if (computeSize>5) {xy_out[5]=activate(M4(result5),cst.activation); }\n"
" if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n"
" if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n"
"}\n"
// conv1x1_w4h4: 16 consecutive output positions per thread. gid.y encodes
// idx_c*batch+idx_b; idx_h is fixed to 0, so positions are addressed through idx_w alone and
// the tail is bounded by output_width.
"kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*16 >= cst.output_width || (int)gid.y >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 4;\n"
" int idx_h=0;\n"
" int idx_c=gid.y/cst.batch;\n"
" int idx_b=gid.y % cst.batch;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n"
" FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n"
" FLOAT4 result20=biasValue,result21=biasValue,result22=biasValue,result23=biasValue;\n"
" FLOAT4 result30=biasValue,result31=biasValue,result32=biasValue,result33=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in00=xy_in0[0];\n"
" auto in01=xy_in0[1];\n"
" auto in02=xy_in0[2];\n"
" auto in03=xy_in0[3];\n"
" auto in10=xy_in0[4];\n"
" auto in11=xy_in0[5];\n"
" auto in12=xy_in0[6];\n"
" auto in13=xy_in0[7];\n"
" \n"
" auto in20=xy_in0[8];\n"
" auto in21=xy_in0[9];\n"
" auto in22=xy_in0[10];\n"
" auto in23=xy_in0[11];\n"
" auto in30=xy_in0[12];\n"
" auto in31=xy_in0[13];\n"
" auto in32=xy_in0[14];\n"
" auto in33=xy_in0[15];\n"
" auto w=xy_wt[z];\n"
" result00 += FLOAT4(in00*w);\n"
" result01 += FLOAT4(in01*w);\n"
" result02 += FLOAT4(in02*w);\n"
" result03 += FLOAT4(in03*w);\n"
" result10 += FLOAT4(in10*w);\n"
" result11 += FLOAT4(in11*w);\n"
" result12 += FLOAT4(in12*w);\n"
" result13 += FLOAT4(in13*w);\n"
" \n"
" result20 += FLOAT4(in20*w);\n"
" result21 += FLOAT4(in21*w);\n"
" result22 += FLOAT4(in22*w);\n"
" result23 += FLOAT4(in23*w);\n"
" result30 += FLOAT4(in30*w);\n"
" result31 += FLOAT4(in31*w);\n"
" result32 += FLOAT4(in32*w);\n"
" result33 += FLOAT4(in33*w);\n"
" \n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,16);\n"
" /* true */ *xy_out=activate(M4(result00),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n"
" if (widthSize>4) {xy_out[4]=activate(M4(result10),cst.activation); }\n"
" if (widthSize>5) {xy_out[5]=activate(M4(result11),cst.activation); }\n"
" if (widthSize>6) {xy_out[6]=activate(M4(result12),cst.activation); }\n"
" if (widthSize>7) {xy_out[7]=activate(M4(result13),cst.activation); }\n"
" if (widthSize>8) {xy_out[8]=activate(M4(result20),cst.activation); }\n"
" if (widthSize>9) {xy_out[9]=activate(M4(result21),cst.activation); }\n"
" if (widthSize>10) {xy_out[10]=activate(M4(result22),cst.activation); }\n"
" if (widthSize>11) {xy_out[11]=activate(M4(result23),cst.activation); }\n"
" if (widthSize>12) {xy_out[12]=activate(M4(result30),cst.activation); }\n"
" if (widthSize>13) {xy_out[13]=activate(M4(result31),cst.activation); }\n"
" if (widthSize>14) {xy_out[14]=activate(M4(result32),cst.activation); }\n"
" if (widthSize>15) {xy_out[15]=activate(M4(result33),cst.activation); }\n"
"}\n"
// conv1x1_w2c2: 2 output positions x 2 output slices per thread. gid.y encodes
// idx_b*channel_pack + (idx_c/2), where channel_pack is the number of slice pairs.
"kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
// Only x is checked here. The y range is validated after decomposition below: the previous
// test gid.y*2 >= batch*output_slice rejected valid (batch, slice) pairs when output_slice
// is odd and batch is greater than 1.
" if ((int)gid.x*2 >= cst.output_width) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" \n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
// When the second slice of the pair does not exist, its weight pointer and bias fall back to
// the first slice so no out-of-range element is read; those results are never stored.
" bool has_c1=(idx_c+1)<cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_wt1=has_c1 ? xy_wt+cst.input_slice : xy_wt;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=has_c1 ? FLOAT4(biasTerms[idx_c+1]) : biasValue0;\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue1,result5=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt1[z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in40*w1);\n"
" result5 += FLOAT4(in41*w1);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ {xy_out[cst.output_size*cst.batch +0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w4c2: 4 output positions x 2 output slices per thread; same indexing and the same
// guards as conv1x1_w2c2.
"kernel void conv1x1_w4c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
// Only x is checked here; y is validated after decomposition (see conv1x1_w2c2 in this shader).
" if ((int)gid.x*4 >= cst.output_width) return;\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
// Fall back to the first slice when the second one does not exist, so nothing out of range
// is read; those results are never stored.
" bool has_c1=(idx_c+1)<cst.output_slice;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_wt1=has_c1 ? xy_wt+cst.input_slice : xy_wt;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=has_c1 ? FLOAT4(biasTerms[idx_c+1]) : biasValue0;\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue0,result5=biasValue0;\n"
" FLOAT4 result2=biasValue1,result3=biasValue1;\n"
" FLOAT4 result6=biasValue1,result7=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[2];\n"
" auto in45=xy_in0[3];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt1[z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in44*w0);\n"
" result5 += FLOAT4(in45*w0);\n"
" result2 += FLOAT4(in40*w1);\n"
" result3 += FLOAT4(in41*w1);\n"
" result6 += FLOAT4(in44*w1);\n"
" result7 += FLOAT4(in45*w1);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result5),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ xy_out[cst.output_size*cst.batch]=activate(M4(result2),cst.activation);\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result3),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_size*cst.batch +2]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_size*cst.batch +3]=activate(M4(result7),cst.activation); }\n"
" }\n"
"}\n"
;
// Metal source for the 4x4-blocked matrix multiply used by the GEMM convolution path, stored
// as one concatenated C string literal. M4, M4x4 and FLOAT4 are not declared in this literal;
// typedefs with those names appear in shader_MetalDefine_metal.
const char* shader_MetalConvolutionGEMM_metal =
// Constant parameters read by matmul4x4_template.
"struct matmul4x4_const {\n"
" int output_width;\n"
" int output_height;\n"
" int multi_length;\n"
" int group;\n"
"};\n"
// matmul4x4_template: each thread with gid.x < output_width and gid.y < output_height
// accumulates multi_length products of a weight block with the four columns of an input
// block, then stores the four results output_width*group elements apart.
"template <typename IType,typename OType>\n"
"static inline void matmul4x4_template(const device IType *in,\n"
" device OType *out,\n"
" const device IType *kt,\n"
" constant matmul4x4_const &cst,\n"
" uint3 gid) {\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height) {\n"
// gid.z offsets both the weight row (ky) and the input row (iy); it is not range-checked
// here (NOTE(review): presumably gid.z is the group index -- confirm with the caller).
" auto ky=(int)gid.y+(int)gid.z*cst.output_height;\n"
" auto iy=(int)gid.x+(int)gid.z*cst.output_width;\n"
" auto off_in=in+iy*cst.multi_length;\n"
" auto off_wt=kt+ky*cst.multi_length;\n"
" auto off_out=out+iy+4*(int)gid.y*cst.output_width*cst.group;\n"
" \n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (int k=0; k<cst.multi_length; ++k) {\n"
" auto w4x4=off_wt[k];\n"
" auto i4x4=off_in[k];\n"
" result0 += FLOAT4(w4x4*i4x4[0]);\n"
" result1 += FLOAT4(w4x4*i4x4[1]);\n"
" result2 += FLOAT4(w4x4*i4x4[2]);\n"
" result3 += FLOAT4(w4x4*i4x4[3]);\n"
" }\n"
" *off_out=OType(result0); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result1); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result2); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result3);\n"
" }\n"
"}\n"
// Kernel entry point: instantiates the template with M4x4 inputs and M4 outputs.
"kernel void matmul4x4(const device M4x4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device M4x4 *kt [[buffer(2)]],\n"
" constant matmul4x4_const &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" matmul4x4_template<M4x4,M4>(in,out,kt,cst,gid);\n"
"}\n"
;
// Metal source for the resize kernels (nearest, bilinear, cubic), stored as one concatenated
// C string literal. M4 is not declared in this literal; a typedef with that name appears in
// shader_MetalDefine_metal.
//
// Change relative to the previous text: resize_nearest now clamps its source coordinates to
// the input extent, as resize_bilinear and resize_cubic already do, so a negative offset or
// float rounding at the edge can no longer index outside the input plane.
const char* shader_MetalResize_metal =
// Constant parameters shared by all three kernels; sliceMap bounds gid.z.
"struct resize_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int sliceMap;\n"
"};\n"
// resize_nearest: source coordinate is gid*scale+offset with s = (x scale, x offset,
// y scale, y offset); the floored coordinate selects the input element to copy.
"kernel void resize_nearest(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
// Clamp to [0, dim-1] so the read below always stays inside the input plane.
" int left=clamp((int)floor(srcX),0,c.input_width-1);\n"
" int top=clamp((int)floor(srcY),0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=in_top[left];\n"
"}\n"
// resize_bilinear: blends the four neighbours of the source coordinate. Neighbour indices
// are clamped to the input extent; the weights come from the unclamped fractional part.
"kernel void resize_bilinear(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
" int srcXInt=int(floor(srcX));\n"
" int srcYInt=int(floor(srcY));\n"
" int left=clamp(srcXInt,0,c.input_width-1);\n"
" int right=clamp(srcXInt+1,0,c.input_width-1);\n"
" int top=clamp(srcYInt,0,c.input_height-1);\n"
" int bottom=clamp(srcYInt+1,0,c.input_height-1);\n"
" float x2_factor=srcX-float(srcXInt);\n"
" float y2_factor=srcY-float(srcYInt);\n"
" float x1_factor=1-x2_factor;\n"
" float y1_factor=1-y2_factor;\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" auto in_bottom=in_z+bottom*c.input_width;\n"
" auto tl=float4(in_top[left])*x1_factor*y1_factor;\n"
" auto tr=float4(in_top[right])*x2_factor*y1_factor;\n"
" auto bl=float4(in_bottom[left])*x1_factor*y2_factor;\n"
" auto br=float4(in_bottom[right])*x2_factor*y2_factor;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=M4(tl+tr+bl+br);\n"
"}\n"
// resize_cubic_interpolation: evaluates a cubic polynomial in factor through four samples
// A..D (Horner form); at factor 0 the result is B.
"static inline float4 resize_cubic_interpolation(float4 A,float4 B,float4 C,float4 D,float factor) {\n"
" float4 a=(B-C)+0.5f*(B-A)+(D-C)*0.5f;\n"
" float4 b=C-((B-A)+(B-C))-(B+D)*0.5f;\n"
" float4 c=(C-A)*0.5f;\n"
" float4 d=B;\n"
" return ((a*factor+b)*factor+c)*factor+d;\n"
"}\n"
// resize_cubic: interpolates four rows along x, then the four row results along y.
// NOTE(review): the fractional part uses floor() while the tap indices use int() truncation;
// the two differ for negative non-integer coordinates -- confirm this is intended.
"kernel void resize_cubic(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" float x=gid.x*s.x+s.y,y=gid.y*s.z+s.w;\n"
" \n"
" float x_factor=x-floor(x);\n"
" float y_factor=y-floor(y);\n"
" \n"
" int4 xp=int4(int(x)-1,int(x)+0,int(x)+1,int(x)+2);\n"
" xp=clamp(xp,0,c.input_width-1);\n"
" \n"
" int4 yp=int4(int(y)-1,int(y)+0,int(y)+1,int(y)+2);\n"
" yp=clamp(yp,0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" float4x4 ABCD;\n"
" for (int i=0; i<4; i++) {\n"
" auto in_y=in_z+yp[i]*c.input_width;\n"
" float4 A=float4(in_y[xp[0]]);\n"
" float4 B=float4(in_y[xp[1]]);\n"
" float4 C=float4(in_y[xp[2]]);\n"
" float4 D=float4(in_y[xp[3]]);\n"
" ABCD[i]=resize_cubic_interpolation(A,B,C,D,x_factor);\n"
" }\n"
" \n"
" auto val=M4(resize_cubic_interpolation(ABCD[0],ABCD[1],ABCD[2],ABCD[3],y_factor));\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=val;\n"
"}\n"
;
// Metal source for the PReLU kernels, stored as one concatenated C string literal.
// M4 is not declared in this literal; a typedef with that name appears in
// shader_MetalDefine_metal.
//
// Change relative to the previous text: prelu_slopes now also rejects gid.z >= s.batch.
// The kernel addresses planes as gid.z+gid.y*s.batch, so a surplus z thread would alias into
// the next slice's plane and write it using the wrong slope.
const char* shader_MetalPReLU_metal =
// Constant parameters read by prelu_slopes.
"struct prelu_shape {\n"
" int size;\n"
" int slice;\n"
" int batch;\n"
"};\n"
// prelu: single shared slope. Elements whose sign bit is set are multiplied by slope, the
// others pass through. No bound is available to this kernel, so it performs no range check
// (NOTE(review): assumes the dispatch covers exactly the element count -- confirm with caller).
"kernel void prelu(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto v4=in[int(gid)];\n"
" out[int(gid)]=select(v4,M4(slope)*v4,signbit(v4));\n"
"}\n"
// prelu_slopes: one slope vector per slice. gid = (position, slice, batch).
"kernel void prelu_slopes(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device float4 *slope [[buffer(2)]],\n"
" constant prelu_shape& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.slice || (int)gid.z >= s.batch) return;\n"
" \n"
// Plane index: planes of one slice are contiguous across the batch.
" int z=gid.z+gid.y*s.batch;\n"
" auto v4=in[z*s.size+int(gid.x)];\n"
" out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n"
"}\n"
;
// Metal source with the shared macros, typedefs and small integer-vector wrapper structs,
// stored as one concatenated C string literal. The other shader strings in this file use
// names defined here (M4, FLOAT4, SIMD_GROUP_WIDTH, MNN::char4x4, MNN::uchar4x2, ...).
const char* shader_MetalDefine_metal =
"using namespace metal;\n"
"// \n"
"// Macro\n"
"// \n"
"#define SIMD_GROUP_WIDTH 32 // setting SIMD group size is 32\n"
"#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n"
"#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n"
"// whether computer with float32 when store with float16\n"
"#define MNN_METAL_FLOAT32_COMPUTER 1 //\n"
// Storage types M*: float when MNN_METAL_FULL_PRECISION is set, half otherwise. That macro
// is not defined in this literal (presumably supplied at shader compile time -- TODO confirm).
"#if MNN_METAL_FULL_PRECISION\n"
"typedef float M;\n"
"typedef float2 M2;\n"
"typedef float3 M3;\n"
"typedef float4 M4;\n"
"typedef float2x2 M2x2;\n"
"typedef float2x3 M2x3;\n"
"typedef float2x4 M2x4;\n"
"typedef float3x2 M3x2;\n"
"typedef float3x3 M3x3;\n"
"typedef float3x4 M3x4;\n"
"typedef float4x2 M4x2;\n"
"typedef float4x3 M4x3;\n"
"typedef float4x4 M4x4;\n"
"#else\n"
"typedef half M;\n"
"typedef half2 M2;\n"
"typedef half3 M3;\n"
"typedef half4 M4;\n"
"typedef half2x2 M2x2;\n"
"typedef half2x3 M2x3;\n"
"typedef half2x4 M2x4;\n"
"typedef half3x2 M3x2;\n"
"typedef half3x3 M3x3;\n"
"typedef half3x4 M3x4;\n"
"typedef half4x2 M4x2;\n"
"typedef half4x3 M4x3;\n"
"typedef half4x4 M4x4;\n"
"#endif\n"
// Compute types FLOAT*: MNN_METAL_FLOAT32_COMPUTER is defined as 1 above, so the float
// branch is the one selected by this text as written.
"#if MNN_METAL_FLOAT32_COMPUTER\n"
"typedef float FLOAT;\n"
"typedef float2 FLOAT2;\n"
"typedef float3 FLOAT3;\n"
"typedef float4 FLOAT4;\n"
"typedef float2x2 FLOAT2x2;\n"
"typedef float2x3 FLOAT2x3;\n"
"typedef float2x4 FLOAT2x4;\n"
"typedef float3x2 FLOAT3x2;\n"
"typedef float3x3 FLOAT3x3;\n"
"typedef float3x4 FLOAT3x4;\n"
"typedef float4x2 FLOAT4x2;\n"
"typedef float4x3 FLOAT4x3;\n"
"typedef float4x4 FLOAT4x4;\n"
"#else\n"
"typedef half FLOAT;\n"
"typedef half2 FLOAT2;\n"
"typedef half3 FLOAT3;\n"
"typedef half4 FLOAT4;\n"
"typedef half2x2 FLOAT2x2;\n"
"typedef half2x3 FLOAT2x3;\n"
"typedef half2x4 FLOAT2x4;\n"
"typedef half3x2 FLOAT3x2;\n"
"typedef half3x3 FLOAT3x3;\n"
"typedef half3x4 FLOAT3x4;\n"
"typedef half4x2 FLOAT4x2;\n"
"typedef half4x3 FLOAT4x3;\n"
"typedef half4x4 FLOAT4x4;\n"
"#endif\n"
"namespace MNN {\n"
" \n"
" // \n"
" // Number Limit\n"
" // \n"
// These limit macros are written inside the namespace braces, but preprocessor macros are
// not scoped by namespaces, so they apply to all shader text that follows.
"#define INT8_MAX 127\n"
"#define INT8_MIN -128\n"
"#define INT16_MAX 32767\n"
"#define INT16_MIN -32768\n"
"#define INT32_MAX 2147483647\n"
"#define INT32_MIN -2147483648\n"
"#define UINT8_MAX 255\n"
"#define UINT16_MAX 65535\n"
"#define UINT32_MAX 4294967295U\n"
" \n"
// num_limits<T>: integer range lookup; the primary template reports 0 for both bounds and
// each specialization returns the matching macro. Every accessor returns int.
" template<typename T> struct num_limits {\n"
" static int max() { return 0; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<char> {\n"
" static int max() { return INT8_MAX; };\n"
" static int min() { return INT8_MIN; };\n"
" };\n"
" template<> struct num_limits<uchar> {\n"
" static int max() { return UINT8_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<short> {\n"
" static int max() { return INT16_MAX; };\n"
" static int min() { return INT16_MIN; };\n"
" };\n"
" template<> struct num_limits<ushort> {\n"
" static int max() { return UINT16_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<int> {\n"
" static int max() { return INT32_MAX; };\n"
" static int min() { return INT32_MIN; };\n"
" };\n"
// NOTE(review): max() returns int, but UINT32_MAX (4294967295U) does not fit in a 32-bit int;
// callers needing the true unsigned maximum should not rely on this value -- verify usage.
" template<> struct num_limits<uint> {\n"
" static int max() { return UINT32_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" \n"
" // \n"
" // Function\n"
" // \n"
// Integer dot product of two int4 vectors.
" inline int dot(int4 i4,int4 w4) {\n"
" return i4[0]*w4[0]+i4[1]*w4[1]+i4[2]*w4[2]+i4[3]*w4[3];\n"
" }\n"
" \n"
// Returns mulhi(a,b)*2. No explicit saturation or rounding step is visible here despite the
// name (NOTE(review): confirm whether callers depend on saturation).
" template <typename T>\n"
" inline T saturate_round_x2_high_mul(T a,int b) {\n"
" return mulhi(a,b)*2;\n"
" }\n"
" \n"
// Arithmetic shift right by exponent, adding 1 when the discarded low bits exceed a
// threshold of (mask >> 1), raised by one for negative x.
" template <typename T>\n"
" inline T round_divide_by_pot(T x,int exponent) {\n"
" int mask=(1 << exponent)-1;\n"
" T remainder=x & mask;\n"
" T threshold=(mask >> 1)+T(x<0);\n"
" return (x >> exponent)+T(remainder>threshold);\n"
" }\n"
" \n"
" // \n"
" // Typedef\n"
" // \n"
" \n"
// short4x4: four short4 values with operator[] overloads for the thread, device and
// threadgroup address spaces, plus explicit conversions to half4x4 and float4x4.
" typedef struct short4x4 {\n"
" private:\n"
" short4 v[4];\n"
" public:\n"
" short4x4(short4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" short4x4(short4 a,short4 b,short4 c,short4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread short4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device short4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup short4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread short4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device short4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup short4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device{\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } short4x4;\n"
" \n"
// char4x4: same shape as short4x4 with char4 elements.
" typedef struct char4x4 {\n"
" private:\n"
" char4 v[4];\n"
" public:\n"
" char4x4(char4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" char4x4(char4 a,char4 b,char4 c,char4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread char4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device char4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup char4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread char4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device char4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup char4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } char4x4;\n"
// char4x2: four char2 values; conversions target half4x2 and float4x2.
" typedef struct char4x2 {\n"
" private:\n"
" char2 v[4];\n"
" public:\n"
" char4x2(char2 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" char4x2(char2 a,char2 b,char2 c,char2 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread char2& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device char2& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup char2& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread char2& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device char2& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup char2& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x2() const {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const device {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const threadgroup {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x2() const {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const device {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const threadgroup {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" } char4x2;\n"
// uchar4x2: unsigned counterpart of char4x2. The 4-bit-weight convolution kernels in this
// file declare their weight buffers with this type.
" typedef struct uchar4x2 {\n"
" private:\n"
" uchar2 v[4];\n"
" public:\n"
" uchar4x2(uchar2 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" uchar4x2(uchar2 a,uchar2 b,uchar2 c,uchar2 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread uchar2& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device uchar2& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup uchar2& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread uchar2& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device uchar2& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup uchar2& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x2() const {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const device {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const threadgroup {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x2() const {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const device {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const threadgroup {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" } uchar4x2;\n"
"}\n"
;
// Metal source for the element-wise binary kernels (product, max, sum), stored as one
// concatenated C string literal. M is not declared in this literal; a typedef with that name
// appears in shader_MetalDefine_metal.
// All three kernels share one shape: thread gid processes element gid when gid < shape.x;
// only the x component of shape is read.
const char* shader_MetalEltwise_metal =
// out[i] = in0[i] * in1[i]
"kernel void eltwise_prod(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]*in1[(int)gid];\n"
" }\n"
"}\n"
// out[i] = max(in0[i], in1[i])
"kernel void eltwise_max(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=max(in0[(int)gid],in1[(int)gid]);\n"
" }\n"
"}\n"
// out[i] = in0[i] + in1[i]
"kernel void eltwise_add(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]+in1[(int)gid];\n"
" }\n"
"}\n"
;