Merge pull request #1533 from jokerz0624/to_merge

GridSample arm82
jxt1234 2021-06-23 20:50:50 +08:00 committed by GitHub
commit 1a6cacc808
6 changed files with 241 additions and 107 deletions

View File

@@ -108,6 +108,95 @@ static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT
}
}
static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners) {
float16x8_t zero = vdupq_n_f16(0);
float16x8_t one = vdupq_n_f16(1);
float16x8_t half = vdupq_n_f16(0.5f);
float16x8_t a = alignCorners ? one : zero;
float16x8_t b = alignCorners ? zero : one;
float16x8_t inW_sub_a = vsubq_f16(vdupq_n_f16(inW), a);
float16x8_t inH_sub_a = vsubq_f16(vdupq_n_f16(inH), a);
int area = outH * outW;
int areaC8 = area / 8;
int areaRemain = area - areaC8 * 8;
for (int i = 0; i < areaC8; ++i) {
auto cordH = vld2q_f16(src);
// float16x8_t x = cordH.val[0];
// float16x8_t y = cordH.val[1];
cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
vst2q_f16(dst, cordH);
src += 16;
dst += 16;
}
for (int i = 0; i < areaRemain; ++i) {
float16x8_t x = vdupq_n_f16(src[0]);
float16x8_t y = vdupq_n_f16(src[1]);
x = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, x), inW_sub_a), b));
y = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, y), inH_sub_a), b));
dst[0] = x[0];
dst[1] = y[0];
src += 2;
dst += 2;
}
}
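The FP16 path above leans on vld2q_f16/vst2q_f16 to de-interleave eight (x, y) grid pairs per iteration, applies the normalization in vector registers, and re-interleaves on store; the tail loop reuses the same arithmetic by splatting one pair and keeping lane 0. A minimal, guarded sketch of just the de-interleave step (hypothetical helper, not part of this PR):

#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#include <arm_neon.h>
// Eight interleaved (x, y) pairs in, two de-interleaved registers out:
// v.val[0] holds x0..x7 and v.val[1] holds y0..y7, matching cordH above.
static inline void deinterleaveXY8(const __fp16* xy, __fp16* xs, __fp16* ys) {
    float16x8x2_t v = vld2q_f16(xy); // reads 16 halves: x0, y0, x1, y1, ...
    vst1q_f16(xs, v.val[0]);
    vst1q_f16(ys, v.val[1]);
}
#endif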
static Vec MNNGridSampleLoadSampleFP16(int h, int w, const FLOAT16 *buffer, int height, int width, bool padMode) {
if (h < 0 || h >= height || w < 0 || w >= width) {
if(padMode == true) { //padMode == BorderMode_ZEROS
return (FLOAT16)0;
}
// Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
// For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
// the leftover reflections degrade to GridSamplePaddingMode_BORDER
h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
}
return Vec::load(buffer + h * width * 8 + w * 8);
}
static void MNNGridSampleInterpFP16(FLOAT16* outputPtr, const FLOAT16* inputPtr, const FLOAT16* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
for (auto ow = 0; ow < outW; ++ow) {
auto w_fp16 = cordPtr[2 * ow + 0];
auto h_fp16 = cordPtr[2 * ow + 1];
float w = (float)(w_fp16);
float h = (float)(h_fp16);
Vec interp;
if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
int nh = vcvtms_s32_f32(h + 0.5f);
int nw = vcvtms_s32_f32(w + 0.5f);
interp = MNNGridSampleLoadSampleFP16(nh, nw, inputPtr, inH, inW, padMode);
} else { //sampleMode == GridSampleMode_BILINEAR
int w0_h = vcvtms_s32_f32(h);
int w0_w = vcvtms_s32_f32(w);
int w1_h = vcvtps_s32_f32(h);
int w1_w = vcvtps_s32_f32(w);
auto oneV = Vec((FLOAT16)1);
Vec i00 = MNNGridSampleLoadSampleFP16(w0_h, w0_w, inputPtr, inH, inW, padMode);
Vec i01 = MNNGridSampleLoadSampleFP16(w0_h, w1_w, inputPtr, inH, inW, padMode);
Vec i10 = MNNGridSampleLoadSampleFP16(w1_h, w0_w, inputPtr, inH, inW, padMode);
Vec i11 = MNNGridSampleLoadSampleFP16(w1_h, w1_w, inputPtr, inH, inW, padMode);
auto f0 = Vec((FLOAT16)w1_w - w_fp16);
auto f1 = oneV - f0;
auto h0 = Vec((FLOAT16)w1_h - h_fp16);
auto h1 = oneV - h0;
Vec i0 = i00 * f0 + i01 * f1;
Vec i1 = i10 * f0 + i11 * f1;
interp = i0 * h0 + i1 * h1;
}
Vec::save(outputPtr + 8 * ow, interp);
}
}
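The bilinear branch mirrors the FP32 kernel: floor/ceil of the fractional coordinate pick the four neighbouring taps, and the distances to the ceil coordinates become the blend weights. A scalar, single-channel sketch of the same weighting (hypothetical helper, for illustration only):

#include <cmath>

// f0/h0 are distances to the ceil coordinates, f1/h1 the complements,
// exactly as in the Vec arithmetic above.
static float bilinearBlend(float h, float w,
                           float i00, float i01, float i10, float i11) {
    float f0 = std::ceil(w) - w, f1 = 1.0f - f0;
    float h0 = std::ceil(h) - h, h1 = 1.0f - h0;
    return (i00 * f0 + i01 * f1) * h0 + (i10 * f0 + i11 * f1) * h1;
}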
static void MNNCopyC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
using Vec = MNN::Math::Vec<FLOAT16, 8>;
for (int i = 0; i < count; ++i) {
@@ -536,6 +625,8 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge);
gInstance->penalty = 2.0f;
FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16);
FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterpFP16);
FUNC_PTR_ASSIGN(gInstance->MNNCopyC4WithStride, MNNCopyC8WithStrideFP16);
FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
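These FUNC_PTR_ASSIGN lines are the point of the Arm82 half of the change: the FP16 kernels are published through the shared function table, so CPUGridSample never has to know which element type it is running on. A rough sketch of that table-dispatch pattern (simplified names, not the real CoreFunctions layout):

#include <cstddef>

// A backend fills the slots it can accelerate; callers only ever go through
// the table, so FP32 and FP16 kernels are interchangeable behind it.
struct GridSampleTable {
    void (*computeCord)(float* dst, const float* src, size_t inH, size_t inW,
                        size_t outH, size_t outW, size_t stride, bool alignCorners);
    void (*interp)(float* out, const float* in, const float* cord,
                   size_t inH, size_t inW, size_t outW, bool nearest, bool zeroPad);
};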

View File

@@ -26,30 +26,70 @@ void MNNNC8HW8TONC4HW4(float* dest, const FLOAT16* source, size_t plane, size_t
template <typename TIN, typename TOUT, int UNIT>
void MNNPackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
int z, x;
int cur = 0;
memset(dst, 0, area * UP_DIV(depth, UNIT) * UNIT * sizeof(TOUT));
for (z = 0; z < depth; ++z) {
int plane = z / UNIT;
TOUT* dstPlane = plane * area * UNIT + dst;
int offset = z % UNIT;
for (x = 0; x < area; ++x) {
dstPlane[UNIT * x + offset] = TOUT(src[cur++]);
int depthCUnit = depth / UNIT;
int depthRemain = depthCUnit * UNIT;
int remain = depth - depthRemain;
int z, x, y;
const TIN* srcChannel[UNIT];
const TIN* srcOffset = src;
for(z = 0; z < depthCUnit; ++z) {
for(y = 0; y < UNIT; ++y) {
srcChannel[y] = srcOffset + area * y;
}
for(x = 0; x < area; ++x) {
for(y = 0; y < UNIT; ++y) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y]++;
dst++;
}
}
srcOffset += area * UNIT;
}
if(remain > 0){
for(y = 0; y < remain; ++y) {
srcChannel[y] = srcOffset + area * y;
}
for(x = 0; x < area; ++x) {
for(y = 0; y < remain; ++y) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y]++;
dst++;
}
for(y = remain; y < UNIT; ++y) {
dst[0] = 0;
dst++;
}
}
}
}
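The rewritten MNNPackUNIT walks UNIT source channels in lockstep instead of recomputing plane and offset per element, and zero-fills the channel tail explicitly rather than memset-ing the whole destination up front. A toy check of the resulting layout (assuming the usual channel-blocked packing, depth = 3, area = 2, UNIT = 4):

#include <cstdio>

int main() {
    const int area = 2, depth = 3, UNIT = 4;
    // Planar source: channel 0 = {1, 2}, channel 1 = {3, 4}, channel 2 = {5, 6}.
    float src[depth * area] = {1, 2, 3, 4, 5, 6};
    float dst[area * UNIT] = {};
    for (int x = 0; x < area; ++x)
        for (int c = 0; c < depth; ++c)
            dst[x * UNIT + c] = src[c * area + x];
    // Packed output per spatial position, padded to UNIT: 1 3 5 0  2 4 6 0
    for (int i = 0; i < area * UNIT; ++i) printf("%g ", dst[i]);
    printf("\n");
    return 0;
}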
template <typename TIN, typename TOUT, int UNIT>
void MNNUnpackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
int x;
int z;
int cur = 0;
for (z = 0; z < depth; ++z) {
int plane = z / UNIT;
const TIN* srcPlane = plane * area * UNIT + src;
int offset = z % UNIT;
for (x = 0; x < area; ++x) {
dst[cur++] = TOUT(srcPlane[UNIT * x + offset]);
int depthCUnit = depth / UNIT;
int depthRemain = depthCUnit * UNIT;
int remain = depth - depthRemain;
int z, x, y;
const TIN* srcChannel[UNIT];
const TIN* srcOffset = src;
for(z = 0; z < depthCUnit; ++z) {
for(y = 0; y < UNIT; ++y) {
srcChannel[y] = srcOffset + y;
for(x = 0; x < area; ++x) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y] += UNIT;
dst++;
}
}
srcOffset += area * UNIT;
}
if(remain > 0){
for(y = 0; y < remain; ++y) {
srcChannel[y] = srcOffset + y;
for(x = 0; x < area; ++x) {
dst[0] = TOUT(srcChannel[y][0]);
srcChannel[y] += UNIT;
dst++;
}
}
}
}
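MNNUnpackUNIT is the inverse walk, striding through the packed buffer UNIT elements at a time per channel. A useful property test for both rewrites is that unpack(pack(x)) reproduces x, including when depth is not a multiple of UNIT (sketch, assuming both templates are visible in the translation unit):

#include <vector>
#include <cassert>

template <int UNIT>
static void checkPackRoundTrip(size_t area, size_t depth) {
    size_t blocks = (depth + UNIT - 1) / UNIT;
    std::vector<float> src(area * depth), packed(area * blocks * UNIT), back(area * depth);
    for (size_t i = 0; i < src.size(); ++i) src[i] = (float)i;
    MNNPackUNIT<float, float, UNIT>(packed.data(), src.data(), area, depth);
    MNNUnpackUNIT<float, float, UNIT>(back.data(), packed.data(), area, depth);
    assert(back == src); // round trip must be lossless for every depth remainder
}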

View File

@@ -26,71 +26,13 @@ CPUGridSample::CPUGridSample(Backend *b, SampleMode mode, BorderMode paddingMode
mAlignCorners = alignCorners;
}
static float getPosition(float x, int range, bool alignCorners) {
float a = alignCorners ? 1.0f : 0.0f;
float b = alignCorners ? 0.0f : 1.0f;
return ((1 + x) * (range - a) - b) / 2.0f;
}
static int CLAMP(int v, int min, int max) {
if ((v) < min) {
(v) = min;
} else if ((v) > max) {
(v) = max;
}
return v;
}
static Vec4 sample(int h, int w, const float *buffer, int height, int width, BorderMode padMode) {
if (h < 0 || h >= height || w < 0 || w >= width) {
if(padMode == BorderMode_ZEROS) {
return 0.0f;
}
// Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
// For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
// the leftover reflections degrade to GridSamplePaddingMode_BORDER
h = CLAMP(h, 0, height - 1);
w = CLAMP(w, 0, width - 1);
}
return Vec4::load(buffer + h * width * 4 + w * 4);
}
static Vec4 interpolate(float h, float w, const float *buffer, int height, int width, SampleMode mode, BorderMode padMode) {
if (mode == SampleMode_NEAREST) {
int nh = ::floor(h+0.5f);
int nw = ::floor(w+0.5f);
return sample(nh, nw, buffer, height, width, padMode);
}
// mode == GridSampleMode_BILINEAR
int w0_h = ::floor(h);
int w0_w = ::floor(w);
int w1_h = ::ceil(h);
int w1_w = ::ceil(w);
auto oneV = Vec4(1.0f);
Vec4 i00 = sample(w0_h, w0_w, buffer, height, width, padMode);
Vec4 i01 = sample(w0_h, w1_w, buffer, height, width, padMode);
Vec4 i10 = sample(w1_h, w0_w, buffer, height, width, padMode);
Vec4 i11 = sample(w1_h, w1_w, buffer, height, width, padMode);
auto f0 = Vec4((float)w1_w - w);
auto f1 = oneV - f0;
auto h0 = Vec4((float)w1_h - h);
auto h1 = oneV - h0;
Vec4 i0 = i00 * f0 + i01 * f1;
Vec4 i1 = i10 * f0 + i11 * f1;
return i0 * h0 + i1 * h1;
}
ErrorCode CPUGridSample::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
int numberThread = static_cast<CPUBackend*>(backend())->threadNumber();
auto core = static_cast<CPUBackend*>(backend())->functions();
auto outputTensor = outputs[0];
auto outH = outputTensor->buffer().dim[2].extent;
auto outW = outputTensor->buffer().dim[3].extent;
mTempCordBuffer.reset(Tensor::createDevice<float>({1, outH * outW * 2}));
mTempCordBuffer.reset(Tensor::createDevice<uint8_t>({1, outH * outW * 2 * core->bytes}));
auto res = backend()->onAcquireBuffer(mTempCordBuffer.get(), Backend::DYNAMIC);
if (!res) {
return OUT_OF_MEMORY;
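The scratch buffer for the computed coordinates is now sized in bytes rather than floats, so the same allocation serves FP32 and the ARM82 FP16 path. In other words (illustrative comment only):

// Two coordinate values per output pixel, element width taken from the backend:
// FP32 uses outH * outW * 2 * 4 bytes, the FP16 path outH * outW * 2 * 2 bytes.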
@@ -103,49 +45,37 @@ ErrorCode CPUGridSample::onExecute(const std::vector<Tensor *> &inputs, const st
auto inputTensor = inputs[0];
auto gridTensor = inputs[1];
auto outputTensor = outputs[0];
float *inputPtr = inputTensor->host<float>();
float *gridPtr = gridTensor->host<float>();
auto *outputPtr = outputTensor->host<float>();
auto core = static_cast<CPUBackend*>(backend())->functions();
auto inputPtr = inputTensor->host<uint8_t>();
auto gridPtr = gridTensor->host<uint8_t>();
auto outputPtr = outputTensor->host<uint8_t>();
auto batches = inputTensor->buffer().dim[0].extent;
auto channels = inputTensor->buffer().dim[1].extent;
auto channelC4 = UP_DIV(channels, 4);
auto channelCUnit = UP_DIV(channels, core->pack);
auto inH = inputTensor->buffer().dim[2].extent;
auto inW = inputTensor->buffer().dim[3].extent;
auto outH = outputTensor->buffer().dim[2].extent;
auto outW = outputTensor->buffer().dim[3].extent;
auto cordPtr = mTempCordBuffer->host<float>();
auto cordPtr = mTempCordBuffer->host<uint8_t>();
auto threadCount = static_cast<CPUBackend*>(backend())->threadNumber();
auto tileCount = channelC4 * outH;
auto tileCount = channelCUnit * outH;
for (auto b = 0; b < batches; ++b) {
const float *_inputPtr = inputPtr + b * inputTensor->buffer().dim[0].stride;
const float *_gridPtr = gridPtr + b * gridTensor->buffer().dim[0].stride;
float *_outputPtr = outputPtr + b * outputTensor->buffer().dim[0].stride;
auto _inputPtr = inputPtr + b * inputTensor->buffer().dim[0].stride * core->bytes;
auto _gridPtr = gridPtr + b * gridTensor->buffer().dim[0].stride * core->bytes;
auto _outputPtr = outputPtr + b * outputTensor->buffer().dim[0].stride * core->bytes;
// Compute cord
for (auto h = 0; h < outH; ++h) {
auto __gridPtr = _gridPtr + h * gridTensor->buffer().dim[1].stride;
auto cordH = cordPtr + h * outW * 2;
for (auto w = 0; w < outW; ++w) {
auto x = getPosition(__gridPtr[2 * w + 0], inW, mAlignCorners);
auto y = getPosition(__gridPtr[2 * w + 1], inH, mAlignCorners);
cordH[2 * w + 0] = x;
cordH[2 * w + 1] = y;
}
}
core->MNNGridSampleComputeCord((float *)cordPtr, (const float *)_gridPtr, inH, inW, outH, outW, gridTensor->buffer().dim[1].stride, mAlignCorners);
MNN_CONCURRENCY_BEGIN(tId, threadCount) {
for (int index=tId; index < tileCount; index += threadCount) {
auto c = index / outH;
auto h = index % outH;
auto inpC = _inputPtr + c * inW * inH * 4;
auto outC = _outputPtr + c * outW * outH * 4;
auto cordH = cordPtr + h * outW * 2;
auto outH = outC + h * outW * 4;
for (auto w = 0; w < outW; ++w) {
auto x = cordH[2 * w + 0];
auto y = cordH[2 * w + 1];
Vec4::save(outH + 4 * w, interpolate(y, x, inpC, inH, inW, mMode, mPaddingMode));
}
auto inputC = _inputPtr + c * inW * inH * core->pack * core->bytes;
auto outputC = _outputPtr + c * outW * outH * core->pack * core->bytes;
auto cordH = cordPtr + h * outW * 2 * core->bytes;
auto outputH = outputC + h * outW * core->pack * core->bytes;
core->MNNGridSampleInterp((float *)outputH, (const float *)inputC, (const float *)cordH, inH, inW, outW, (mMode == SampleMode_NEAREST), (mPaddingMode == BorderMode_ZEROS));
}
}
MNN_CONCURRENCY_END();
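onExecute now flattens the (channel block, output row) pairs into tileCount work items and deals them round-robin across threads, each item resolving one packed row through the table's MNNGridSampleInterp. A standalone illustration of that work split (toy sizes, print only):

#include <cstdio>

int main() {
    const int channelCUnit = 3, outH = 4, threadCount = 2;
    const int tileCount = channelCUnit * outH;
    for (int tId = 0; tId < threadCount; ++tId) {
        printf("thread %d:", tId);
        for (int index = tId; index < tileCount; index += threadCount)
            printf(" (c=%d, h=%d)", index / outH, index % outH);
        printf("\n");
    }
    return 0;
}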

View File

@@ -1459,6 +1459,71 @@ void MNNPowC8(float* dest, const float* source, const float* powfParam, size_t b
#endif // no MNN_USE_NEON
void MNNGridSampleComputeCord(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners) {
float a = alignCorners ? 1.0f : 0.0f;
float b = alignCorners ? 0.0f : 1.0f;
for (auto h = 0; h < outH; ++h) {
auto __gridPtr = src + h * stride;
auto cordH = dst + h * outW * 2;
for (auto w = 0; w < outW; ++w) {
auto x = __gridPtr[2 * w + 0];
auto y = __gridPtr[2 * w + 1];
cordH[2 * w + 0] = ((1 + x) * (inW - a) - b) * 0.5f;
cordH[2 * w + 1] = ((1 + y) * (inH - a) - b) * 0.5f;
}
}
}
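A couple of concrete values pin the mapping down: with inW = 4 and alignCorners == false (a = 0, b = 1), x = -1 lands at -0.5 and x = 1 at 3.5, i.e. the grid corners sit half a pixel outside the first and last pixel centers; with alignCorners == true they land exactly on 0 and 3. A tiny constexpr mirror of the loop body (hypothetical helper, only to verify the numbers):

constexpr float mapCoord(float x, int range, bool alignCorners) {
    return ((1 + x) * (range - (alignCorners ? 1.0f : 0.0f))
            - (alignCorners ? 0.0f : 1.0f)) * 0.5f;
}
static_assert(mapCoord(-1.0f, 4, false) == -0.5f, "unaligned corner");
static_assert(mapCoord( 1.0f, 4, false) ==  3.5f, "unaligned corner");
static_assert(mapCoord(-1.0f, 4, true)  ==  0.0f, "aligned corner");
static_assert(mapCoord( 1.0f, 4, true)  ==  3.0f, "aligned corner");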
Vec4 MNNGridSampleLoadSample(int h, int w, const float *buffer, int height, int width, bool padMode) {
if (h < 0 || h >= height || w < 0 || w >= width) {
if(padMode == true) { //padMode == BorderMode_ZEROS
return 0.0f;
}
// Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
// For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
// the leftover reflections degrade to GridSamplePaddingMode_BORDER
h = h < 0 ? 0 : ( h > (height - 1) ? (height - 1) : h);
w = w < 0 ? 0 : ( w > (width - 1) ? (width - 1) : w);
}
return Vec4::load(buffer + h * width * 4 + w * 4);
}
void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
for (auto ow = 0; ow < outW; ++ow) {
auto w = cordPtr[2 * ow + 0];
auto h = cordPtr[2 * ow + 1];
Vec4 interp;
if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
int nh = ::floor(h + 0.5f);
int nw = ::floor(w + 0.5f);
interp = MNNGridSampleLoadSample(nh, nw, inputPtr, inH, inW, padMode);
} else { //sampleMode == GridSampleMode_BILINEAR
int w0_h = ::floor(h);
int w0_w = ::floor(w);
int w1_h = ::ceil(h);
int w1_w = ::ceil(w);
auto oneV = Vec4(1.0f);
Vec4 i00 = MNNGridSampleLoadSample(w0_h, w0_w, inputPtr, inH, inW, padMode);
Vec4 i01 = MNNGridSampleLoadSample(w0_h, w1_w, inputPtr, inH, inW, padMode);
Vec4 i10 = MNNGridSampleLoadSample(w1_h, w0_w, inputPtr, inH, inW, padMode);
Vec4 i11 = MNNGridSampleLoadSample(w1_h, w1_w, inputPtr, inH, inW, padMode);
auto f0 = Vec4((float)w1_w - w);
auto f1 = oneV - f0;
auto h0 = Vec4((float)w1_h - h);
auto h1 = oneV - h0;
Vec4 i0 = i00 * f0 + i01 * f1;
Vec4 i1 = i10 * f0 + i11 * f1;
interp = i0 * h0 + i1 * h1;
}
Vec4::save(outputPtr + 4 * ow, interp);
}
}
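Taken together, the operator first materializes per-pixel source coordinates with MNNGridSampleComputeCord and then resolves them row by row with MNNGridSampleInterp. A minimal single-row driver, assuming NC4HW4 input, an (x, y)-interleaved grid with outW * 2 floats per row, and the declarations above:

#include <vector>

static void gridSampleOneRow(const float* inputC4, const float* grid,
                             size_t inH, size_t inW, size_t outH, size_t outW,
                             bool alignCorners, bool nearest, bool zeroPad,
                             float* outRow0) {
    std::vector<float> cord(outH * outW * 2);
    MNNGridSampleComputeCord(cord.data(), grid, inH, inW, outH, outW,
                             outW * 2, alignCorners);
    // Only the first output row of the first channel block; a real caller
    // loops over rows and channel blocks exactly as CPUGridSample does.
    MNNGridSampleInterp(outRow0, inputC4, cord.data(), inH, inW, outW,
                        nearest, zeroPad);
}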
void MNNPackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth) {
int z, x;
@@ -2515,6 +2580,8 @@ void MNNCoreFunctionInit() {
gCoreFunction->MNNStrassenMergeCFunction = MNNStrassenMergeCFunction;
gCoreFunction->penalty = 1.5f;
gCoreFunction->MNNScaleAndAddBias = MNNScaleAndAddBias;
gCoreFunction->MNNGridSampleComputeCord = MNNGridSampleComputeCord;
gCoreFunction->MNNGridSampleInterp = MNNGridSampleInterp;
gCoreFunction->MNNAddC4WithStride = MNNAddC4WithStride;
gCoreFunction->MNNCopyC4WithStride = MNNCopyC4WithStride;

View File

@@ -43,6 +43,9 @@ void MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const f
size_t biasNumber);
void MNNScaleAndAddBiasScalar(float* dst, const float* src, float bias, float alpha, size_t number);
void MNNGridSampleComputeCord(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
void MNNUnpackTranspose(float* dst, const float* src, size_t area, size_t depth);
void MNNUnpackTransposeInt16(int16_t* dst, const int16_t* src, size_t area, size_t depth);
void MNNUnpackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth);
@@ -191,6 +194,8 @@ struct CoreFunctions {
size_t bStride, size_t height);
void(*MNNStrassenMergeCFunction)(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub);
void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
void(*MNNGridSampleComputeCord)(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
void(*MNNGridSampleInterp)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
float penalty;
void(*MNNCopyC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);

View File

@@ -357,6 +357,7 @@ bool OpCommonUtils::opCompabilityForLowp(const Op* op) {
case OpType_ReLU:
case OpType_ReLU6:
case OpType_PReLU:
case OpType_GridSample:
return true;
default:
break;