MNN/source/backend/vulkan/execution/glsl/binaryImage.comp

54 lines
1.2 KiB
Plaintext

#version 440 core
// Destination 3D image (write-only, rgba16f) that receives the per-texel result.
layout(set=0, binding=0, rgba16f) writeonly restrict uniform image3D uOutput;
// First operand, read with texelFetch at mip level 0.
layout(set=0, binding=1) uniform sampler3D uInput0;
// Second operand, read with texelFetch at mip level 0 at the same coordinate as uInput0.
layout(set=0, binding=2) uniform sampler3D uInput1;
// Constants supplied by the host.
layout(set=0, binding=3) uniform constBuffer{
// stride00.xyz: extents main() uses to turn the flat invocation index into a 3D
//               coordinate (x varies fastest, then y, then z).
// stride00.w:   limit on the flat index; invocations with index >= w do nothing.
// Original note: "WHC, LIMIT" (presumably width / height / channel-blocks; confirm with the host code).
ivec4 stride00;//WHC, LIMIT
// The remaining members are declared but not referenced by main() in this shader.
ivec4 stride01;
ivec4 stride10;
ivec4 stride11;
ivec4 stride20;
ivec4 stride21;
} uConstant;
// 256 invocations per workgroup along x; work is indexed by the x component only.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Applies one element-wise binary operator to uInput0 and uInput1 and stores the
// result in uOutput. The operator is selected at compile time by a preprocessor
// define (ADD, SUB, MUL, DIV, VMAX, VMIN); with none defined, the texel from
// uInput0 is written through unchanged.
void main()
{
    // Each invocation handles one texel, addressed by a flat index in x.
    int flatIndex = int(gl_GlobalInvocationID.x);

    // Invocations at or past the limit in stride00.w write nothing.
    if (flatIndex >= uConstant.stride00.w)
    {
        return;
    }

    // Unflatten the index: x varies fastest, then y, then z.
    ivec3 extent = uConstant.stride00.xyz;
    int planeSize = extent.x * extent.y;
    int inPlane = flatIndex % planeSize;
    ivec3 coord = ivec3(inPlane % extent.x, inPlane / extent.x, flatIndex / planeSize);

    vec4 lhs = texelFetch(uInput0, coord, 0);
    vec4 rhs = texelFetch(uInput1, coord, 0);

    // Checked from VMIN down to ADD so that, if several operator macros are
    // defined at once, the same one takes effect as with sequential overrides.
#if defined(VMIN)
    vec4 result = min(lhs, rhs);
#elif defined(VMAX)
    vec4 result = max(lhs, rhs);
#elif defined(DIV)
    vec4 result = lhs / rhs;
#elif defined(MUL)
    vec4 result = lhs * rhs;
#elif defined(SUB)
    vec4 result = lhs - rhs;
#elif defined(ADD)
    vec4 result = lhs + rhs;
#else
    vec4 result = lhs;
#endif

    imageStore(uOutput, coord, result);
}