mirror of https://github.com/alibaba/MNN.git
47 lines
1.1 KiB
Plaintext
47 lines
1.1 KiB
Plaintext
|
#version 440 core

// Source values, read-only. main() indexes this as a flat float array.
layout(std430, set = 0, binding = 0) readonly buffer srcBuffer
{
    float data[];
} uInput;

// Destination 3D image (rgba16f texels); this shader only writes to it.
layout(rgba16f, set = 0, binding = 1) restrict writeonly uniform image3D uOutput;

// Size parameters consumed by main().
layout(set = 0, binding = 2) uniform constBuffer
{
    ivec4 info; // x = ow, y = oz, z = 0, w = 0 (z and w are not read by this shader)
} uLSTM;

// 16 invocations per workgroup along x; y and z keep their default size of 1.
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
|
||
|
|
||
|
// Packs up to four consecutive rows of uInput.data into one RGBA texel of uOutput.
//
// Per invocation:
//   pos.x - column inside a row of `ow` floats (uLSTM.info.x)
//   pos.z - index of a group of four rows; uLSTM.info.y is used as the total
//           row (channel) count
// Row r of the group is read from data[pos.x + (pos.z * 4 + r) * ow].
// Components whose row would fall past the last one stay at 0.0.
void main()
{
    ivec3 pos = ivec3(gl_GlobalInvocationID);

    int ow = uLSTM.info.x;
    int channels = uLSTM.info.y;

    // pos.z counts groups of four rows, so it must be bounded by the number of
    // groups, not by the raw row count. Comparing pos.z against `channels`
    // directly would let groups that start past the last row through, which
    // then read uInput.data beyond the rows described by `info`.
    // NOTE(review): assumes uInput holds exactly ow * oz floats - confirm against the host code.
    if (pos.x >= ow || pos.z * 4 >= channels)
    {
        return;
    }

    int inputIndex = pos.x + pos.z * 4 * ow;

    // Rows still available starting at this group; always >= 1 after the guard.
    // A value of 4 or more means the group is complete.
    int remaining = channels - pos.z * 4;

    vec4 temp = vec4(0.0);
    temp.x = uInput.data[inputIndex];
    if (remaining > 1)
    {
        temp.y = uInput.data[inputIndex + ow];
    }
    if (remaining > 2)
    {
        temp.z = uInput.data[inputIndex + 2 * ow];
    }
    if (remaining > 3)
    {
        temp.w = uInput.data[inputIndex + 3 * ow];
    }

    imageStore(uOutput, pos, temp);
}
|