MNN/source/backend/vulkan/execution/glsl/col2ImMali.comp

44 lines
1.2 KiB
Plaintext

#version 440 core
// Col2Im output pass (Mali variant): gathers values from a 2D intermediate
// image, adds a per-slice bias, optionally applies RELU / RELU6, and writes
// them into a 3D output image.

// Default layout for uniform blocks.
// NOTE(review): GLSL core restricts std430 to buffer blocks; this relies on the
// compiler accepting it for uniform blocks. For the ivec2/ivec4/int members of
// constBuffer below, std140 and std430 produce identical offsets, so the host
// layout is the same either way.
layout(std430) uniform;
// Intermediate 2D image read by main() at (index / 4, 4 * oz + index % 4).
layout(set=0, rgba16f, binding=0) readonly mediump uniform image2D uInput;
// Destination 3D image; main() stores to the invocation's global (x, y, z),
// where z encodes batch * outputSize.z + slice.
layout(set=0, rgba16f, binding=1) writeonly uniform image3D uOutput;
// Bias image; main() reads texel (oz, 0), i.e. one vec4 per output slice.
layout(set=0, rgba16f, binding=2) readonly uniform image2D uBias;
// Convolution parameters written by the host. Member order and types must stay
// in sync with the host-side parameter struct (offsets: pad 0, kernelSize 8,
// stride 16, dilate 24, inputSize 32, outputSize 48, batch 64, group 68).
layout(set=0, binding=3) readonly uniform constBuffer {
ivec2 pad;
ivec2 kernelSize;
ivec2 stride;
ivec2 dilate;
ivec4 inputSize;
ivec4 outputSize;
int batch;
int group;
} uConstant;
// 16x16 tiles over the output x/y; z is dispatched one slice per invocation
// (local_size_z = 1), so only x/y can be over-dispatched by workgroup rounding.
layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Ceiling integer division; not referenced by main() in this shader.
#define UP_DIV(x, y) (((x)+(y)-1)/(y))
// One invocation per output texel. gl_GlobalInvocationID.xy is the output
// x/y position; .z packs (batch, slice) as z = ob * outputSize.z + oz.
// Reads the matching uInput texel, adds the bias for slice oz, applies the
// optional RELU / RELU6 activation (compile-time defines) and stores the
// result at the same (x, y, z) in uOutput.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);
    ivec3 outputSize = uConstant.outputSize.xyz;
    // Split the packed z coordinate into slice index and batch index.
    int oz = pos.z % outputSize.z;
    int ob = pos.z / outputSize.z;
    // x/y are rounded up to whole 16x16 workgroups, so clip them. Also reject
    // z values past the last batch item: without this, such an invocation
    // would read uInput beyond the batch and store outside uOutput.
    if (all(lessThan(pos.xy, outputSize.xy)) && ob < uConstant.batch)
    {
        // Linear pixel index over (x, y, batch), x fastest.
        int sourceXIndex = pos.x + pos.y * outputSize.x + ob * outputSize.x * outputSize.y;
        // uInput stores pixel index i of slice oz at column i / 4,
        // row 4 * oz + i % 4.
        int sourceX = sourceXIndex / 4;
        int sourceY = 4 * oz + sourceXIndex % 4;
        vec4 color = imageLoad(uInput, ivec2(sourceX, sourceY)) + imageLoad(uBias, ivec2(oz, 0));
#ifdef RELU
        color = max(color, vec4(0));
#endif
#ifdef RELU6
        color = clamp(color, vec4(0), vec4(6));
#endif
        imageStore(uOutput, pos, color);
    }
}