MNN/source/backend/vulkan/execution/glsl/binaryBroadcast.comp

88 lines
2.2 KiB
Plaintext

#version 440 core
// Resource interface for the element-wise binary shader below.
// All three storage buffers are flat float arrays addressed by element index.

// Destination buffer (despite the block name "sourceBuffer"): main() writes
// one float per invocation at the flat index gl_GlobalInvocationID.x.
layout(set=0, binding=0) writeonly buffer sourceBuffer{
float data[];
} uOutput;
// First operand (x0). Read at an offset built from stride10/stride11.
layout(set=0, binding=1) readonly buffer s0Buffer{
float data[];
}uInput0;
// Second operand (x1). Read at an offset built from stride20/stride21.
layout(set=0, binding=2) readonly buffer s1Buffer{
float data[];
}uInput1;
// Indexing parameters. main() splits the flat output index into six
// coordinates by dividing successively by stride00.x, .y, .z, .w and then
// stride01.x, .y (taking the remainder after each step). Each coordinate is
// multiplied by the matching component of stride1* / stride2* to form the
// element offsets into uInput0 / uInput1.
// NOTE(review): a zero multiplier makes an input reuse the same element along
// that axis, which is presumably how broadcasting is expressed - confirm
// against the host code that fills this block.
layout(set=0, binding=3) uniform constBuffer{
ivec4 stride00; // divisors for coordinates 0..3 of the flat output index
ivec4 stride01; // .x/.y: divisors for coordinates 4..5; .w: exclusive upper bound on the flat output index; .z is not read by main()
ivec4 stride10; // uInput0 offset multipliers for coordinates 0..3
ivec4 stride11; // uInput0 offset multipliers for coordinates 4..5 (.x/.y); .z/.w are not read by main()
ivec4 stride20; // uInput1 offset multipliers for coordinates 0..3
ivec4 stride21; // uInput1 offset multipliers for coordinates 4..5 (.x/.y); .z/.w are not read by main()
} uConstant;
// 256 invocations per workgroup along x; main() only uses the x component of
// the global invocation id.
layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Peels one coordinate off the remaining flat index.
//   remaining : in/out, flat index still to be decomposed; replaced by its
//               remainder modulo `extent`
//   offset0/1 : in/out, running element offsets into uInput0 / uInput1
//   extent    : divisor for this coordinate
//   step0/1   : per-coordinate offset multipliers for uInput0 / uInput1
void peelCoordinate(inout int remaining, inout int offset0, inout int offset1,
                    int extent, int step0, int step1)
{
    int coord = remaining / extent;
    remaining = remaining % extent;
    offset0 += coord * step0;
    offset1 += coord * step1;
}

// One invocation per output element: decompose the flat output index into six
// coordinates, map them to an element offset in each input, apply the binary
// operator selected at compile time, and store the result.
void main()
{
    int outIndex = int(gl_GlobalInvocationID.x);

    // Invocations at or beyond the limit in stride01.w do nothing.
    if (outIndex >= uConstant.stride01.w) {
        return;
    }

    int offset0 = 0;
    int offset1 = 0;
    int remaining = outIndex;

    // Coordinates 0..3 use the four components of stride00 / stride10 / stride20.
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride00.x, uConstant.stride10.x, uConstant.stride20.x);
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride00.y, uConstant.stride10.y, uConstant.stride20.y);
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride00.z, uConstant.stride10.z, uConstant.stride20.z);
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride00.w, uConstant.stride10.w, uConstant.stride20.w);

    // Coordinates 4..5 use the x/y components of stride01 / stride11 / stride21.
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride01.x, uConstant.stride11.x, uConstant.stride21.x);
    peelCoordinate(remaining, offset0, offset1,
                   uConstant.stride01.y, uConstant.stride11.y, uConstant.stride21.y);

    float lhs = uInput0.data[offset0];
    float rhs = uInput1.data[offset1];

    // With no operator macro defined the result is simply the first operand.
    // The blocks are independent (not an #elif chain), so if several macros
    // are defined the last one below takes effect.
    float result = lhs;
#ifdef ADD
    result = lhs + rhs;
#endif
#ifdef SUB
    result = lhs - rhs;
#endif
#ifdef MUL
    result = lhs * rhs;
#endif
#ifdef DIV
    result = lhs / rhs;
#endif
#ifdef VMAX
    result = max(lhs, rhs);
#endif
#ifdef VMIN
    result = min(lhs, rhs);
#endif

    uOutput.data[outIndex] = result;
}