MNN/source/backend/vulkan/execution/glsl/convolution.comp

75 lines
2.3 KiB
Plaintext
Raw Normal View History

2019-04-17 10:49:11 +08:00
#version 440 core
// Shader interface for the convolution compute kernel: an output image,
// three read-only samplers (input, weights, bias) and one block of
// integer parameters. All resources live in descriptor set 0.

// Default layout qualifier for buffer blocks declared in this shader.
layout(std430) buffer;
// Destination image; main() writes one RGBA texel per invocation with imageStore.
layout(set=0, binding=0, rgba16f) writeonly restrict mediump uniform image3D uOutput;
// Source feature map, read with integer texel coordinates (texelFetch, lod 0).
layout(set=0, binding=1) uniform mediump sampler3D uInput;
// Convolution weights. main() indexes it as (4*inputSlice + lane, outputSlice, kernelTap).
layout(set=0, binding=2) uniform mediump sampler3D uKernel;
// Per-output-slice bias, fetched at (outputSlice, 0).
layout(set=0, binding=3) uniform mediump sampler2D uBias;
// Convolution parameters. NOTE(review): member order presumably has to match
// the structure the host code uploads - confirm before reordering anything.
layout(set=0, binding=4) uniform constBuffer {
// Padding subtracted from (output position * stride) to get the first input tap.
ivec2 pad;
// Kernel extent in taps along x and y.
ivec2 kernelSize;
// Step of the kernel window in input texels per output texel.
ivec2 stride;
// Distance in input texels between neighbouring kernel taps.
ivec2 dilate;
// main() reads .xyz; .z is the number of input slices summed per output texel.
ivec4 inputSize;
// .xy bound the output position, .z splits the z coordinate into
// (slice, batch) and .w bounds the batch index.
ivec4 outputSize;
// Not read by main() in this shader.
int batch;
// Not read by main() in this shader.
int group;
} uConstant;
// Integer ceiling division for non-negative x and positive y.
#define UP_DIV(x, y) (((x)+(y)-1)/(y))
// Workgroup shape: 2 x 2 x 16 invocations.
layout (local_size_x = 2, local_size_y = 2, local_size_z = 16) in;
// Direct convolution entry point. Each invocation computes one RGBA texel of
// uOutput at gl_GlobalInvocationID:
//   - x/y select the output column/row,
//   - z packs the batch index and the output slice as ob * outputSize.z + oz.
// For every input slice and every kernel tap, the four components of the input
// texel are multiplied with four consecutive kernel texels and accumulated.
// The bias is then added and, if RELU / RELU6 is defined at compile time, the
// corresponding activation is applied before the store.
//
// Kernel taps that would land outside the input in x or y are skipped, which
// is equivalent to zero padding. Fetching those texels instead would rely on
// out-of-range texelFetch results, which are undefined unless robust image
// access is enabled.
// NOTE(review): this assumes uConstant.inputSize.xy is the texel extent of
// uInput in x/y and that uConstant.dilate is >= 1 in both components - confirm
// against the host code that fills constBuffer.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    ivec3 outputSize = uConstant.outputSize.xyz;
    // Split the packed z coordinate into output slice (oz) and batch index (ob).
    int oz = pos.z % uConstant.outputSize.z;
    int ob = pos.z / uConstant.outputSize.z;
    // Invocations outside the output volume (dispatch round-up) do nothing.
    if (all(lessThan(pos.xy, outputSize.xy)) && ob < uConstant.outputSize.w)
    {
        ivec3 inputSize = uConstant.inputSize.xyz;
        // Input coordinate touched by kernel tap (0, 0); negative inside the padding.
        ivec2 s0 = pos.xy*uConstant.stride-uConstant.pad;

        // Restrict the tap range to taps whose input coordinate is in bounds:
        //   tapBegin = smallest f with f*dilate + s0 >= 0
        //   tapEnd   = smallest f with f*dilate + s0 >= inputSize (capped at kernelSize)
        // Both dividends are clamped to >= 0 so UP_DIV only sees non-negative values.
        ivec2 texelsBefore = max(-s0, ivec2(0));
        ivec2 texelsLeft = max(inputSize.xy - s0, ivec2(0));
        ivec2 tapBegin = UP_DIV(texelsBefore, uConstant.dilate);
        ivec2 tapEnd = min(uConstant.kernelSize, UP_DIV(texelsLeft, uConstant.dilate));

        int fx, fy, fz;
        vec4 color = vec4(0);
        // First input slice belonging to this batch element.
        int startZ = ob*inputSize.z;
        for (fz=0; fz<inputSize.z; ++fz)
        {
            for (fy=tapBegin.y; fy<tapEnd.y; ++fy)
            {
                int sy = fy*uConstant.dilate.y + s0.y;
                for (fx=tapBegin.x; fx<tapEnd.x; ++fx)
                {
                    int sx = fx*uConstant.dilate.x + s0.x;
                    // Flattened kernel tap index selects the z slice of uKernel.
                    int tap = fx+fy*uConstant.kernelSize.x;
                    vec4 inputValue = texelFetch(uInput, ivec3(sx, sy, fz+startZ), 0);
                    // One kernel texel per input component of this slice.
                    vec4 k0 = texelFetch(uKernel, ivec3(4*fz+0, oz, tap), 0);
                    vec4 k1 = texelFetch(uKernel, ivec3(4*fz+1, oz, tap), 0);
                    vec4 k2 = texelFetch(uKernel, ivec3(4*fz+2, oz, tap), 0);
                    vec4 k3 = texelFetch(uKernel, ivec3(4*fz+3, oz, tap), 0);
                    color += k0*inputValue.x;
                    color += k1*inputValue.y;
                    color += k2*inputValue.z;
                    color += k3*inputValue.w;
                }
            }
        }
        color = color + texelFetch(uBias, ivec2(oz, 0), 0);
#ifdef RELU
        color = max(color, vec4(0));
#endif
#ifdef RELU6
        color = clamp(color, vec4(0), vec4(6));
#endif
        imageStore(uOutput, pos, color);
    }
}