mirror of https://github.com/alibaba/MNN.git
75 lines
2.3 KiB
Plaintext
75 lines
2.3 KiB
Plaintext
|
#version 440 core

// Default layout qualifier for buffer blocks declared in this shader.
layout(std430) buffer;

// ---- Resource bindings (descriptor set 0) ----------------------------------

// binding 0: write-only destination image, RGBA16F texels.
layout(set=0, binding=0, rgba16f) writeonly restrict mediump uniform image3D uOutput;

// binding 1: source feature map, read with integer texel coordinates.
layout(set=0, binding=1) uniform mediump sampler3D uInput;

// binding 2: convolution weights. main() addresses it as
//            (4*inputSlice + component, outputSlice, kernelTap).
layout(set=0, binding=2) uniform mediump sampler3D uKernel;

// binding 3: per-output-slice bias, fetched from row 0.
layout(set=0, binding=3) uniform mediump sampler2D uBias;

// binding 4: convolution parameters supplied by the host.
layout(set=0, binding=4) uniform constBuffer {
    ivec2 pad;          // subtracted from pos.xy * stride to get the window origin
    ivec2 kernelSize;   // kernel extent in x / y (loop bounds for the taps)
    ivec2 stride;       // output-to-input step in x / y
    ivec2 dilate;       // spacing between neighbouring kernel taps in x / y
    ivec4 inputSize;    // .xy: input extent; .z: number of input slices per batch entry
    ivec4 outputSize;   // .xy: output extent; .z: output slices per batch entry; .w: upper bound for the batch index
    int batch;          // not read by main() in this shader
    int group;          // not read by main() in this shader
} uConstant;

// Integer division of x by y, rounded up (for non-negative x and positive y).
#define UP_DIV(x, y) (((x)+(y)-1)/(y))

// Workgroup shape: 2 x 2 x 16 invocations.
layout (local_size_x = 2, local_size_y = 2, local_size_z = 16) in;
|
||
|
|
||
|
// Direct 2D convolution: each invocation produces one RGBA texel of uOutput.
//
// gl_GlobalInvocationID.xy is the output (x, y) position. The z component is
// split into
//     oz = z % outputSize.z   -> output slice (row of uKernel, column of uBias)
//     ob = z / outputSize.z   -> batch index, must be < outputSize.w
//
// For every input slice fz and every kernel tap (fx, fy) the input texel is
// fetched and its x/y/z/w components weight the four kernel texels
// (4*fz+0 .. 4*fz+3, oz, fx + fy*kernelSize.x). The bias texel (oz, 0) is then
// added, RELU / RELU6 are applied when the corresponding macro is defined, and
// the result is stored at the invocation's own position.
//
// Kernel taps that would read outside the input in x or y are skipped, which
// is equivalent to zero padding. Fetching them is not safe: texelFetch with
// out-of-range coordinates yields undefined values.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    int oz = pos.z % uConstant.outputSize.z;
    int ob = pos.z / uConstant.outputSize.z;

    // Invocations outside the output volume (workgroup round-up) do nothing.
    if (!(all(lessThan(pos.xy, uConstant.outputSize.xy)) && ob < uConstant.outputSize.w))
    {
        return;
    }

    ivec3 inputSize = uConstant.inputSize.xyz;

    // Input coordinate hit by kernel tap (0, 0); may be negative because of pad.
    ivec2 s0 = pos.xy * uConstant.stride - uConstant.pad;

    // Clip the tap range to taps that land inside [0, inputSize.xy):
    //   first tap f with f*dilate + s0 >= 0         -> ceil(max(-s0, 0) / dilate)
    //   first tap f with f*dilate + s0 >= inputSize -> ceil(max(inputSize - s0, 0) / dilate)
    // Numerators are clamped to >= 0 so UP_DIV never divides a negative value.
    // NOTE(review): assumes dilate >= 1 on both axes, since it is used as a divisor.
    ivec2 sfxy = UP_DIV(max(-s0, ivec2(0)), uConstant.dilate);
    ivec2 efxy = min(uConstant.kernelSize,
                     UP_DIV(max(inputSize.xy - s0, ivec2(0)), uConstant.dilate));

    vec4 color = vec4(0);

    // First input slice belonging to this batch entry.
    int startZ = ob * inputSize.z;

    for (int fz = 0; fz < inputSize.z; ++fz)
    {
        // Four consecutive kernel texels belong to one input slice.
        int kx = 4 * fz;

        for (int fy = sfxy.y; fy < efxy.y; ++fy)
        {
            int sy = fy * uConstant.dilate.y + s0.y;

            for (int fx = sfxy.x; fx < efxy.x; ++fx)
            {
                int sx = fx * uConstant.dilate.x + s0.x;

                // Kernel taps are flattened row-major along the third axis.
                int kTap = fx + fy * uConstant.kernelSize.x;

                vec4 inputValue = texelFetch(uInput, ivec3(sx, sy, fz + startZ), 0);

                vec4 k0 = texelFetch(uKernel, ivec3(kx + 0, oz, kTap), 0);
                vec4 k1 = texelFetch(uKernel, ivec3(kx + 1, oz, kTap), 0);
                vec4 k2 = texelFetch(uKernel, ivec3(kx + 2, oz, kTap), 0);
                vec4 k3 = texelFetch(uKernel, ivec3(kx + 3, oz, kTap), 0);

                color += k0 * inputValue.x;
                color += k1 * inputValue.y;
                color += k2 * inputValue.z;
                color += k3 * inputValue.w;
            }
        }
    }

    color = color + texelFetch(uBias, ivec2(oz, 0), 0);

#ifdef RELU
    color = max(color, vec4(0));
#endif
#ifdef RELU6
    color = clamp(color, vec4(0), vec4(6));
#endif

    imageStore(uOutput, pos, color);
}
|