MNN/source/backend/vulkan/execution/glsl/permute.comp

41 lines
1.1 KiB
Plaintext

#version 310 es
// Interface declarations for the permute compute shader: one read-only float
// buffer, one write-only float buffer, and a small block of integer
// parameters describing the permutation and the tensor extents.

// Default precision for float in this shader is mediump.
#define PRECISION mediump
precision PRECISION float;
// Make std430 the default memory layout for the buffer and uniform blocks
// declared below.
layout(std430) buffer;
layout(std430) uniform;
// Binding 0: source tensor, read as a flat array of floats.
layout(set = 0, binding = 0) readonly buffer srcBuffer{
float data[];
}uInput;
// Binding 1: destination tensor, written as a flat array of floats.
layout(set = 0, binding = 1) writeonly buffer dstBuffer{
float data[];
}uOutput;
// Binding 2: permutation parameters.
layout(set = 0, binding = 2) uniform constBuffer{
// Axis mapping. main() uses dims.y / dims.z / dims.w as the input-axis slot
// that receives the output z / y / x coordinate; dims.x is not read there.
ivec4 dims;
// Input extents. main() only uses .x and .y when computing the input offset.
ivec4 inImSize; // width, height, channel, batch
// Output extents. main() only uses .x and .y (bounds check and output offset).
// NOTE(review): presumably the same width, height, channel, batch order as
// inImSize -- confirm against the host code that fills this block.
ivec4 outImSize;
}uPermuteParam;
// Workgroup size: 8 x 8 x 1 invocations.
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
// Copies one float per invocation from uInput to uOutput, applying the axis
// permutation given by uPermuteParam.dims.
//
// The global invocation id is used as the output coordinate (x, y, z). Both
// buffers are addressed as planar data (the original author's note: input and
// output are both NCHW): offset = z * (width * height) + y * width + x.
//
// NOTE(review): neither offset contains a batch term and dims.x is never
// read, so this presumably expects a single batch (or the batch axis left in
// place) -- confirm against the host code.
void main()
{
    ivec3 outCoord  = ivec3(gl_GlobalInvocationID);
    ivec3 srcExtent = uPermuteParam.inImSize.xyz;
    ivec3 dstExtent = uPermuteParam.outImSize.xyz;
    ivec4 axisMap   = uPermuteParam.dims;

    // Invocations that fall outside the output width/height do nothing.
    // NOTE(review): outCoord.z is not range-checked here; presumably the
    // dispatch z count matches the output exactly -- confirm.
    if (outCoord.x >= dstExtent.x || outCoord.y >= dstExtent.y)
    {
        return;
    }

    // Scatter the output coordinate into input-axis slots. Slots 1, 2 and 3
    // are read back below as the plane, row and column of the source element.
    // Slot axisMap.x is never written, so axisMap.yzw must cover slots 1..3
    // for every value read below to be defined.
    int srcCoord[4];
    srcCoord[axisMap.y] = outCoord.z;
    srcCoord[axisMap.z] = outCoord.y;
    srcCoord[axisMap.w] = outCoord.x;

    int srcPlane = srcExtent.x * srcExtent.y;
    int dstPlane = dstExtent.x * dstExtent.y;

    int srcOffset = srcCoord[1] * srcPlane + srcCoord[2] * srcExtent.x + srcCoord[3];
    int dstOffset = outCoord.z * dstPlane + outCoord.y * dstExtent.x + outCoord.x;

    uOutput.data[dstOffset] = uInput.data[srcOffset];
}