// Mirrored from https://github.com/alibaba/MNN.git (original file: 171 lines, 4.9 KiB, plain text).
#version 450 core

// Every buffer block declared below uses std430 packing by default.
layout(std430) buffer;

// ---------------------------------------------------------------------------
// Resource bindings, all in descriptor set 0. Each tensor is a flat float
// array; the index arithmetic lives in LoadSample() and main().
// ---------------------------------------------------------------------------

// binding 0: result tensor, one float written per invocation.
layout(set = 0, binding = 0) writeonly buffer destBuffer {
    float data[];
} uOutput;

// binding 1: source tensor; main() addresses six images (one per cube face)
// for every output batch entry.
layout(set = 0, binding = 1) readonly buffer sourceBuffer0 {
    float data[];
} uInput;

// binding 2: direction vectors; main() reads three consecutive floats
// (x, y, z) per output pixel.
layout(set = 0, binding = 2) readonly buffer sourceBuffer1 {
    float data[];
} uGrid;

// binding 3: shape parameters.
layout(set = 0, binding = 3) uniform gridSampleBuffer {
    ivec4 inShape;      // x = input width, y = input height, z = channel count,
                        // w = float stride between consecutive grid points
                        //     (presumably the grid's last dimension - TODO confirm)
    ivec4 outShape;     // x = output width, y = output height,
                        // z = channel count, w = number of output batches
    bool alignCorners;  // NOTE(review): declared but never read by this shader
} uGridSampleParam;

// One invocation per output element, 256 invocations per workgroup.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Maps a direction vector to a cube-map face and face-local coordinates,
// following the OpenGL cube-map face selection table (major axis -> face,
// then sc/tc divided by |major axis| and remapped from [-1, 1] to [0, 1]).
//
// d    : direction to look up; it does not need to be normalized.
// face : out; 0..5 for the +X, -X, +Y, -Y, +Z, -Z faces, or -1 when d has no
//        usable major axis (zero vector, or NaN components).
// s, t : out; coordinates on the selected face, within [0, 1] for finite d.
//        Both are 0 when face == -1.
//
// When two components tie for the largest magnitude, Z wins over Y and Y wins
// over X.
void indexCubeMap(vec3 d, out int face, out float s, out float t) {
    vec3 absd = abs(d);

    // Initialized up front: if d contains a NaN, none of the comparisons
    // below succeed, and these values route the call into the face == -1
    // branch instead of reading undefined locals.
    float sc = 0.0f;
    float tc = 0.0f;
    float ma = 0.0f;
    face = -1;

    if ((absd.z >= absd.x) && (absd.z >= absd.y)) {
        ma = absd.z;
        tc = -d.y;
        if (d.z > 0.0f) {
            face = 4;
            sc = d.x;
        } else {
            face = 5;
            sc = -d.x;
        }
    } else if ((absd.y >= absd.x) && (absd.y >= absd.z)) {
        ma = absd.y;
        sc = d.x;
        if (d.y > 0.0f) {
            face = 2;
            tc = d.z;
        } else {
            face = 3;
            tc = -d.z;
        }
    } else if ((absd.x >= absd.y) && (absd.x >= absd.z)) {
        ma = absd.x;
        tc = -d.y;
        if (d.x > 0.0f) {
            face = 0;
            sc = -d.z;
        } else {
            face = 1;
            sc = d.z;
        }
    }

    if (ma == 0.0f) {
        // Zero-length (or NaN) direction: no face can be selected.
        s = 0.0f;
        t = 0.0f;
        face = -1;
    } else {
        s = ((sc / ma) + 1.0f) * 0.5f;
        t = ((tc / ma) + 1.0f) * 0.5f;
    }
}
// Reads channel c of texel (positionX, positionY) from input image n.
// The input is addressed as a dense [image][row][column][channel] float array
// with inShape.x columns, inShape.y rows and inShape.z channels.
// Out-of-range coordinates:
//   - PAD_MODE_ZEROS defined: the sample is 0.0;
//   - otherwise: each coordinate is clamped to the nearest edge texel.
float LoadSample(int positionX, int positionY, int c, int n) {
    int cols = uGridSampleParam.inShape.x;
    int rows = uGridSampleParam.inShape.y;
    int channels = uGridSampleParam.inShape.z;

#ifdef PAD_MODE_ZEROS
    bool inside = positionX >= 0 && positionX < cols
               && positionY >= 0 && positionY < rows;
    if (!inside) {
        return 0.0;
    }
#else
    positionX = clamp(positionX, 0, cols - 1);
    positionY = clamp(positionY, 0, rows - 1);
#endif

    int pixel = (n * rows + positionY) * cols + positionX;
    return uInput.data[pixel * channels + c];
}
// Samples channel c of input image n at face-local coordinates (s, t).
// Uses nearest-neighbour lookup when NEAREST is defined and bilinear
// interpolation otherwise. Taps that fall outside the image are resolved by
// LoadSample (zero padding or edge clamping).
float sampleFace(float s, float t, int c, int n) {
    ivec4 inputShape = uGridSampleParam.inShape;
#ifdef NEAREST
    // Texel whose cell contains (s * width, t * height).
    int positionX = int(floor(s * float(inputShape.x)));
    int positionY = int(floor(t * float(inputShape.y)));
    return LoadSample(positionX, positionY, c, n);
#else
    // Unnormalize so that texel centres sit at integer coordinates (the -0.5
    // shift), then blend the four surrounding taps.
    float cordW = s * float(inputShape.x) - 0.5;
    float cordH = t * float(inputShape.y) - 0.5;
    int w0_w = int(floor(cordW));
    int w0_h = int(floor(cordH));
    int w1_w = w0_w + 1;
    int w1_h = w0_h + 1;

    float i00 = LoadSample(w0_w, w0_h, c, n);
    float i01 = LoadSample(w1_w, w0_h, c, n);
    float i10 = LoadSample(w0_w, w1_h, c, n);
    float i11 = LoadSample(w1_w, w1_h, c, n);

    // f0 / h0 weight the lower-index tap along each axis.
    float f0 = float(w1_w) - cordW;
    float f1 = 1.0 - f0;
    float h0 = float(w1_h) - cordH;
    float h1 = 1.0 - h0;

    float i0 = i00 * f0 + i01 * f1;
    float i1 = i10 * f0 + i11 * f1;
    return i0 * h0 + i1 * h1;
#endif
}

// One invocation per output element.
// Input and output are indexed as dense NHWC float arrays (channel fastest):
// element (batch, y, x, c) of a W x H x C tensor lives at
// ((batch * H + y) * W + x) * C + c.
void main()
{
    int pos = int(gl_GlobalInvocationID.x);

    ivec4 outputShape = uGridSampleParam.outShape;  // (W, H, C, batches)
    int total = outputShape.x * outputShape.y * outputShape.z * outputShape.w;
    if (pos >= total) {
        return;
    }

    // The invocation index varies fastest in x, then y, then channel, then batch.
    int x = pos % outputShape.x;
    int tmp = pos / outputShape.x;
    int y = tmp % outputShape.y;
    tmp = tmp / outputShape.y;
    int z = tmp % outputShape.z;
    int on = tmp / outputShape.z;

    // Linear pixel index (batch, y, x) within the output; the grid holds one
    // direction per output pixel, inShape.w floats apart.
    int gridPosition = (on * outputShape.y + y) * outputShape.x + x;
    int gridOffset = uGridSampleParam.inShape.w * gridPosition;
    vec3 direction = vec3(uGrid.data[gridOffset + 0],
                          uGrid.data[gridOffset + 1],
                          uGrid.data[gridOffset + 2]);

    int face;
    float gridX;
    float gridY;
    indexCubeMap(direction, face, gridX, gridY);

    // Directions with no selectable face produce 0.
    float value = 0.0;
    if (face >= 0) {
        // Each output batch entry owns six consecutive input images, one per face.
        value = sampleFace(gridX, gridY, z, on * 6 + face);
    }

    uOutput.data[gridPosition * outputShape.z + z] = value;
}