// Mirror of https://github.com/alibaba/MNN.git (88 lines, 2.2 KiB, plaintext)
#version 440 core

// Interface of an element-wise binary compute shader over two float inputs.
// main() selects the operation at compile time through one of the macros
// ADD, SUB, MUL, DIV, VMAX or VMIN; with none defined it copies uInput0.

// Destination buffer. main() writes one float at the invocation's linear index.
layout(set=0, binding=0) writeonly buffer sourceBuffer{
    float data[];
} uOutput;

// First operand, read at an offset built from stride10 / stride11.
layout(set=0, binding=1) readonly buffer s0Buffer{
    float data[];
}uInput0;

// Second operand, read at an offset built from stride20 / stride21.
layout(set=0, binding=2) readonly buffer s1Buffer{
    float data[];
}uInput1;

// Index-decomposition parameters.
// main() divides the linear index successively by stride00.x, .y, .z, .w and
// stride01.x, .y, producing six coordinates. Each coordinate is multiplied by
// the matching component of stride10/stride11 (for uInput0) and
// stride20/stride21 (for uInput1).
// NOTE(review): all six divisors are presumably non-zero -- confirm on the host side.
layout(set=0, binding=3) uniform constBuffer{
    ivec4 stride00; // divisors for coordinates 0..3
    ivec4 stride01; // Last set limit: .x/.y are divisors for coordinates 4..5, .w is the bound on the linear index; .z is not read by main()
    ivec4 stride10; // uInput0 multipliers for coordinates 0..3
    ivec4 stride11; // uInput0 multipliers for coordinates 4..5 (.x, .y); .z/.w are not read by main()
    ivec4 stride20; // uInput1 multipliers for coordinates 0..3
    ivec4 stride21; // uInput1 multipliers for coordinates 4..5 (.x, .y); .z/.w are not read by main()
} uConstant;

layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Computes one output element per invocation.
//
// The invocation's linear index (x component of gl_GlobalInvocationID) is
// split into six coordinates by successive division with stride00.xyzw and
// stride01.xy. Each coordinate is scaled by the per-input multipliers to get
// the read offsets into uInput0 and uInput1. The two values are combined by
// the operation selected with ADD / SUB / MUL / DIV / VMAX / VMIN, and the
// result is stored at the linear index in uOutput.
void main()
{
    // Only the x dimension of the dispatch is used.
    int index = int(gl_GlobalInvocationID.x);

    // stride01.w is the bound on the linear index; invocations at or past it
    // neither read nor write.
    if (index >= uConstant.stride01.w) {
        return;
    }

    int pos0 = 0;   // read offset into uInput0
    int pos1 = 0;   // read offset into uInput1
    int ind  = 0;   // current coordinate
    int cur  = index; // remainder still to be decomposed

    // Coordinate 0
    ind = cur / uConstant.stride00.x;
    cur = cur % uConstant.stride00.x;
    pos0 += ind * uConstant.stride10.x;
    pos1 += ind * uConstant.stride20.x;

    // Coordinate 1
    ind = cur / uConstant.stride00.y;
    cur = cur % uConstant.stride00.y;
    pos0 += ind * uConstant.stride10.y;
    pos1 += ind * uConstant.stride20.y;

    // Coordinate 2
    ind = cur / uConstant.stride00.z;
    cur = cur % uConstant.stride00.z;
    pos0 += ind * uConstant.stride10.z;
    pos1 += ind * uConstant.stride20.z;

    // Coordinate 3
    ind = cur / uConstant.stride00.w;
    cur = cur % uConstant.stride00.w;
    pos0 += ind * uConstant.stride10.w;
    pos1 += ind * uConstant.stride20.w;

    // Coordinate 4
    ind = cur / uConstant.stride01.x;
    cur = cur % uConstant.stride01.x;
    pos0 += ind * uConstant.stride11.x;
    pos1 += ind * uConstant.stride21.x;

    // Coordinate 5 (the remainder left in cur after this step is not used)
    ind = cur / uConstant.stride01.y;
    cur = cur % uConstant.stride01.y;
    pos0 += ind * uConstant.stride11.y;
    pos1 += ind * uConstant.stride21.y;

    float x0 = uInput0.data[pos0];
    float x1 = uInput1.data[pos1];

    // Defaults to the first operand when no operation macro is defined.
    // The blocks are independent #ifdefs, so if several macros are defined
    // the last matching assignment wins.
    float value = x0;
#ifdef ADD
    value = x0 + x1;
#endif
#ifdef SUB
    value = x0 - x1;
#endif
#ifdef MUL
    value = x0 * x1;
#endif
#ifdef DIV
    value = x0 / x1;
#endif
#ifdef VMAX
    value = max(x0, x1);
#endif
#ifdef VMIN
    value = min(x0, x1);
#endif

    uOutput.data[index] = value;
}