MNN/source/backend/vulkan/execution/glsl/gemm16x16Half.comp

61 lines
2.1 KiB
Plaintext

#version 450 core
// Half-precision types (f16vec4 / f16mat4) used in main() come from the
// explicit-arithmetic-types extensions enabled below.
// NOTE(review): the GL_KHX_ prefix looks like the provisional name of what is
// now published as GL_EXT_shader_explicit_arithmetic_types(_float16) — confirm
// that the shader compiler used by the build still accepts these names.
#extension GL_KHX_shader_explicit_arithmetic_types: enable
#extension GL_KHX_shader_explicit_arithmetic_types_float16: enable
// Default memory layout for buffer blocks; no buffer block is declared in the
// visible part of this file.
layout(std140) buffer;
// Destination image (rgba16f), only ever written via imageStore in main().
layout(set=0, binding=0, rgba16f) writeonly restrict mediump uniform image2D uOutput;
// Source operand, read with texelFetch at LOD 0 in main().
layout(set=0, binding=1) mediump uniform sampler2D uInput;
// Kernel operand, read with texelFetch at LOD 0 in main().
layout(set=0, binding=2) mediump uniform sampler2D uKernel;
// Dispatch parameters supplied by the host.
layout(set=0, binding=3) readonly restrict uniform constBuffer {
// x / y bound the invocations that do work and scale the z offset of the rows
// read from uInput / uKernel; z and w are not read by this shader.
ivec4 outputSize;
// Number of loop iterations in main(); each iteration consumes 4 consecutive
// texels along x from both uInput and uKernel.
int multiLength;
}uConst;
// 8 x 8 x 1 invocations per workgroup.
layout (local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
// Per-invocation GEMM tile in half precision.
//
// For every step along the shared dimension, four texels of one uKernel row
// are assembled into a 4x4 matrix and multiplied with four texels of one
// uInput row; the four products are accumulated separately and finally
// written to four consecutive rows of uOutput.
void main()
{
    ivec3 gid = ivec3(gl_GlobalInvocationID);

    // Invocations outside the x/y extent given by the host have nothing to do.
    if (gid.x >= uConst.outputSize.x || gid.y >= uConst.outputSize.y)
    {
        return;
    }

    // Rows fetched from the two operands; gid.z shifts each by a whole
    // outputSize.y (kernel) or outputSize.x (input) block of rows.
    int kernelRow = gid.y + gid.z * uConst.outputSize.y;
    int inputRow  = gid.x + gid.z * uConst.outputSize.x;

    f16vec4 acc0 = f16vec4(0);
    f16vec4 acc1 = f16vec4(0);
    f16vec4 acc2 = f16vec4(0);
    f16vec4 acc3 = f16vec4(0);

    for (int step = 0; step < uConst.multiLength; ++step)
    {
        // First of the four consecutive texel columns handled in this step.
        int col = 4 * step;

        // Columns of the matrix are the four kernel texels of this step.
        f16mat4 weights = f16mat4(
            f16vec4(texelFetch(uKernel, ivec2(col + 0, kernelRow), 0)),
            f16vec4(texelFetch(uKernel, ivec2(col + 1, kernelRow), 0)),
            f16vec4(texelFetch(uKernel, ivec2(col + 2, kernelRow), 0)),
            f16vec4(texelFetch(uKernel, ivec2(col + 3, kernelRow), 0)));

        acc0 += weights * f16vec4(texelFetch(uInput, ivec2(col + 0, inputRow), 0));
        acc1 += weights * f16vec4(texelFetch(uInput, ivec2(col + 1, inputRow), 0));
        acc2 += weights * f16vec4(texelFetch(uInput, ivec2(col + 2, inputRow), 0));
        acc3 += weights * f16vec4(texelFetch(uInput, ivec2(col + 3, inputRow), 0));
    }

    // The four accumulators land in rows 4*gid.y .. 4*gid.y+3 of column inputRow.
    // NOTE(review): the destination row has no gid.z term while the column
    // does — confirm this matches the host-side output layout.
    int dstRow = 4 * gid.y;
    imageStore(uOutput, ivec2(inputRow, dstRow + 0), acc0);
    imageStore(uOutput, ivec2(inputRow, dstRow + 1), acc1);
    imageStore(uOutput, ivec2(inputRow, dstRow + 2), acc2);
    imageStore(uOutput, ivec2(inputRow, dstRow + 3), acc3);
}