mirror of https://github.com/alibaba/MNN.git
commit 1a6cacc808
@@ -108,6 +108,95 @@ static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT
     }
 }
 
+static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners) {
+    float16x8_t zero = vdupq_n_f16(0);
+    float16x8_t one = vdupq_n_f16(1);
+    float16x8_t half = vdupq_n_f16(0.5f);
+    float16x8_t a = alignCorners ? one : zero;
+    float16x8_t b = alignCorners ? zero : one;
+    float16x8_t inW_sub_a = vsubq_f16(vdupq_n_f16(inW), a);
+    float16x8_t inH_sub_a = vsubq_f16(vdupq_n_f16(inH), a);
+
+    int area = outH * outW;
+    int areaC8 = area / 8;
+    int areaRemain = area - areaC8 * 8;
+    for (int i = 0; i < areaC8; ++i) {
+        auto cordH = vld2q_f16(src);
+        // float16x8_t x = cordH.val[0];
+        // float16x8_t y = cordH.val[1];
+        cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
+        cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
+        vst2q_f16(dst, cordH);
+
+        src += 16;
+        dst += 16;
+    }
+
+    for (int i = 0; i < areaRemain; ++i) {
+        float16x8_t x = vdupq_n_f16(src[0]);
+        float16x8_t y = vdupq_n_f16(src[1]);
+        x = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, x), inW_sub_a), b));
+        y = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, y), inH_sub_a), b));
+        dst[0] = x[0];
+        dst[1] = y[0];
+
+        src += 2;
+        dst += 2;
+    }
+}
+
+static Vec MNNGridSampleLoadSampleFP16(int h, int w, const FLOAT16 *buffer, int height, int width, bool padMode) {
+    if (h < 0 || h >= height || w < 0 || w >= width) {
+        if(padMode == true) { //padMode == BorderMode_ZEROS
+            return (FLOAT16)0;
+        }
+        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
+        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
+        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
+        h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
+        w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
+    }
+    return Vec::load(buffer + h * width * 8 + w * 8);
+}
+
+static void MNNGridSampleInterpFP16(FLOAT16* outputPtr, const FLOAT16* inputPtr, const FLOAT16* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
+    for (auto ow = 0; ow < outW; ++ow) {
+        auto w_fp16 = cordPtr[2 * ow + 0];
+        auto h_fp16 = cordPtr[2 * ow + 1];
+        float w = (float)(w_fp16);
+        float h = (float)(h_fp16);
+        Vec interp;
+
+        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
+            int nh = vcvtms_s32_f32(h + 0.5f);
+            int nw = vcvtms_s32_f32(w + 0.5f);
+            interp = MNNGridSampleLoadSampleFP16(nh, nw, inputPtr, inH, inW, padMode);
+        } else { //sampleMode == GridSampleMode_BILINEAR
+            int w0_h = vcvtms_s32_f32(h);
+            int w0_w = vcvtms_s32_f32(w);
+            int w1_h = vcvtps_s32_f32(h);
+            int w1_w = vcvtps_s32_f32(w);
+            auto oneV = Vec((FLOAT16)1);
+
+            Vec i00 = MNNGridSampleLoadSampleFP16(w0_h, w0_w, inputPtr, inH, inW, padMode);
+            Vec i01 = MNNGridSampleLoadSampleFP16(w0_h, w1_w, inputPtr, inH, inW, padMode);
+            Vec i10 = MNNGridSampleLoadSampleFP16(w1_h, w0_w, inputPtr, inH, inW, padMode);
+            Vec i11 = MNNGridSampleLoadSampleFP16(w1_h, w1_w, inputPtr, inH, inW, padMode);
+            auto f0 = Vec((FLOAT16)w1_w - w_fp16);
+            auto f1 = oneV - f0;
+            auto h0 = Vec((FLOAT16)w1_h - h_fp16);
+            auto h1 = oneV - h0;
+
+            Vec i0 = i00 * f0 + i01 * f1;
+            Vec i1 = i10 * f0 + i11 * f1;
+
+            interp = i0 * h0 + i1 * h1;
+        }
+
+        Vec::save(outputPtr + 8 * ow, interp);
+    }
+}
+
 static void MNNCopyC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
     using Vec = MNN::Math::Vec<FLOAT16, 8>;
     for (int i = 0; i < count; ++i) {
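Note: the FP16 kernels above vectorize the same coordinate transform the scalar backend uses, ((1 + x) * (size - a) - b) / 2, with a = 1, b = 0 when alignCorners is set and a = 0, b = 1 otherwise. A minimal standalone sketch (plain C++, not MNN code) that checks both conventions using the scalar getPosition() helper this commit removes from CPUGridSample.cpp:

#include <cassert>
#include <cstdio>

// Scalar form of the transform that MNNGridSampleComputeCordFP16 above
// computes 8 fp16 lanes at a time.
static float getPosition(float x, int range, bool alignCorners) {
    float a = alignCorners ? 1.0f : 0.0f;
    float b = alignCorners ? 0.0f : 1.0f;
    return ((1 + x) * (range - a) - b) / 2.0f;
}

int main() {
    // alignCorners == true: -1/+1 land exactly on the first/last pixel centers.
    assert(getPosition(-1.0f, 4, true) == 0.0f);
    assert(getPosition(1.0f, 4, true) == 3.0f);
    // alignCorners == false: -1/+1 land on the outer pixel edges, half a pixel
    // beyond the first/last pixel centers.
    assert(getPosition(-1.0f, 4, false) == -0.5f);
    assert(getPosition(1.0f, 4, false) == 3.5f);
    printf("coordinate transform checks passed\n");
    return 0;
}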
@@ -536,6 +625,8 @@ bool Arm82Functions::init() {
     FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge);
     gInstance->penalty = 2.0f;
     FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16);
+    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterpFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNCopyC4WithStride, MNNCopyC8WithStrideFP16);
     FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
@@ -26,30 +26,70 @@ void MNNNC8HW8TONC4HW4(float* dest, const FLOAT16* source, size_t plane, size_t
 
 template <typename TIN, typename TOUT, int UNIT>
 void MNNPackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
-    int z, x;
-    int cur = 0;
-    memset(dst, 0, area * UP_DIV(depth, UNIT) * UNIT * sizeof(TOUT));
-    for (z = 0; z < depth; ++z) {
-        int plane = z / UNIT;
-        TOUT* dstPlane = plane * area * UNIT + dst;
-        int offset = z % UNIT;
-        for (x = 0; x < area; ++x) {
-            dstPlane[UNIT * x + offset] = TOUT(src[cur++]);
-        }
-    }
+    int depthCUnit = depth / UNIT;
+    int depthRemain = depthCUnit * UNIT;
+    int remain = depth - depthRemain;
+    int z, x, y;
+    const TIN* srcChannel[UNIT];
+    const TIN* srcOffset = src;
+    for(z = 0; z < depthCUnit; ++z) {
+        for(y = 0; y < UNIT; ++y) {
+            srcChannel[y] = srcOffset + area * y;
+        }
+        for(x = 0; x < area; ++x) {
+            for(y = 0; y < UNIT; ++y) {
+                dst[0] = TOUT(srcChannel[y][0]);
+                srcChannel[y]++;
+                dst++;
+            }
+        }
+        srcOffset += area * UNIT;
+    }
+    if(remain > 0){
+        for(y = 0; y < remain; ++y) {
+            srcChannel[y] = srcOffset + area * y;
+        }
+        for(x = 0; x < area; ++x) {
+            for(y = 0; y < remain; ++y) {
+                dst[0] = TOUT(srcChannel[y][0]);
+                srcChannel[y]++;
+                dst++;
+            }
+            for(y = remain; y < UNIT; ++y) {
+                dst[0] = 0;
+                dst++;
+            }
+        }
+    }
 }
 
 template <typename TIN, typename TOUT, int UNIT>
 void MNNUnpackUNIT(TOUT* dst, const TIN* src, size_t area, size_t depth) {
-    int x;
-    int z;
-    int cur = 0;
-    for (z = 0; z < depth; ++z) {
-        int plane = z / UNIT;
-        const TIN* srcPlane = plane * area * UNIT + src;
-        int offset = z % UNIT;
-        for (x = 0; x < area; ++x) {
-            dst[cur++] = TOUT(srcPlane[UNIT * x + offset]);
-        }
-    }
+    int depthCUnit = depth / UNIT;
+    int depthRemain = depthCUnit * UNIT;
+    int remain = depth - depthRemain;
+    int z, x, y;
+    const TIN* srcChannel[UNIT];
+    const TIN* srcOffset = src;
+    for(z = 0; z < depthCUnit; ++z) {
+        for(y = 0; y < UNIT; ++y) {
+            srcChannel[y] = srcOffset + y;
+            for(x = 0; x < area; ++x) {
+                dst[0] = TOUT(srcChannel[y][0]);
+                srcChannel[y] += UNIT;
+                dst++;
+            }
+        }
+        srcOffset += area * UNIT;
+    }
+    if(remain > 0){
+        for(y = 0; y < remain; ++y) {
+            srcChannel[y] = srcOffset + y;
+            for(x = 0; x < area; ++x) {
+                dst[0] = TOUT(srcChannel[y][0]);
+                srcChannel[y] += UNIT;
+                dst++;
+            }
+        }
+    }
 }
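The rewritten MNNPackUNIT walks UNIT source channels in lockstep instead of recomputing a plane index and offset per element, and it zero-fills only the channel remainder instead of memset-ing the whole destination. A standalone sketch of the resulting layout (illustrative only, depth = 3, area = 2, UNIT = 4):

#include <cstdio>

// Planar src[depth][area] -> interleaved dst[UP_DIV(depth, UNIT)][area][UNIT],
// with the channel remainder zero-padded, as MNNPackUNIT above produces.
int main() {
    const int depth = 3, area = 2, UNIT = 4;
    float src[depth * area] = {
        10, 11,   // channel 0 over the plane
        20, 21,   // channel 1
        30, 31,   // channel 2
    };
    float dst[area * UNIT] = {0}; // one UNIT block, since UP_DIV(3, 4) == 1
    for (int x = 0; x < area; ++x) {
        for (int y = 0; y < UNIT; ++y) {
            dst[x * UNIT + y] = (y < depth) ? src[y * area + x] : 0.0f;
        }
    }
    // Expected: 10 20 30 0   11 21 31 0
    for (int i = 0; i < area * UNIT; ++i) {
        printf("%g ", dst[i]);
    }
    printf("\n");
    return 0;
}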
@@ -26,71 +26,13 @@ CPUGridSample::CPUGridSample(Backend *b, SampleMode mode, BorderMode paddingMode
     mAlignCorners = alignCorners;
 }
 
-static float getPosition(float x, int range, bool alignCorners) {
-    float a = alignCorners ? 1.0f : 0.0f;
-    float b = alignCorners ? 0.0f : 1.0f;
-    return ((1 + x) * (range - a) - b) / 2.0f;
-}
-
-static int CLAMP(int v, int min, int max) {
-    if ((v) < min) {
-        (v) = min;
-    } else if ((v) > max) {
-        (v) = max;
-    }
-    return v;
-}
-
-static Vec4 sample(int h, int w, const float *buffer, int height, int width, BorderMode padMode) {
-    if (h < 0 || h >= height || w < 0 || w >= width) {
-        if(padMode == BorderMode_ZEROS) {
-            return 0.0f;
-        }
-        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
-        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
-        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
-        h = CLAMP(h, 0, height - 1);
-        w = CLAMP(w, 0, width - 1);
-    }
-
-    return Vec4::load(buffer + h * width * 4 + w * 4);
-}
-
-static Vec4 interpolate(float h, float w, const float *buffer, int height, int width, SampleMode mode, BorderMode padMode) {
-    if (mode == SampleMode_NEAREST) {
-        int nh = ::floor(h+0.5f);
-        int nw = ::floor(w+0.5f);
-        return sample(nh, nw, buffer, height, width, padMode);
-    }
-    // mode == GridSampleMode_BILINEAR
-    int w0_h = ::floor(h);
-    int w0_w = ::floor(w);
-    int w1_h = ::ceil(h);
-    int w1_w = ::ceil(w);
-    auto oneV = Vec4(1.0f);
-
-    Vec4 i00 = sample(w0_h, w0_w, buffer, height, width, padMode);
-    Vec4 i01 = sample(w0_h, w1_w, buffer, height, width, padMode);
-    Vec4 i10 = sample(w1_h, w0_w, buffer, height, width, padMode);
-    Vec4 i11 = sample(w1_h, w1_w, buffer, height, width, padMode);
-    auto f0 = Vec4((float)w1_w - w);
-    auto f1 = oneV - f0;
-    auto h0 = Vec4((float)w1_h - h);
-    auto h1 = oneV - h0;
-
-    Vec4 i0 = i00 * f0 + i01 * f1;
-    Vec4 i1 = i10 * f0 + i11 * f1;
-
-    return i0 * h0 + i1 * h1;
-}
-
-
 ErrorCode CPUGridSample::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     int numberThread = static_cast<CPUBackend*>(backend())->threadNumber();
+    auto core = static_cast<CPUBackend*>(backend())->functions();
     auto outputTensor = outputs[0];
     auto outH = outputTensor->buffer().dim[2].extent;
     auto outW = outputTensor->buffer().dim[3].extent;
-    mTempCordBuffer.reset(Tensor::createDevice<float>({1, outH * outW * 2}));
+    mTempCordBuffer.reset(Tensor::createDevice<uint8_t>({1, outH * outW * 2 * core->bytes}));
     auto res = backend()->onAcquireBuffer(mTempCordBuffer.get(), Backend::DYNAMIC);
     if (!res) {
         return OUT_OF_MEMORY;
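onResize now sizes the coordinate buffer in bytes rather than floats: two coordinates per output point, scaled by the backend's element size (core->bytes), so one code path serves the fp32 backend (4-byte elements) and the ARM82 fp16 backend (2-byte elements). A trivial standalone sketch of that arithmetic (not MNN code):

#include <cstddef>
#include <cstdio>

// Size of the temp cord buffer allocated in onResize above: two coordinates
// (x, y) per output point, times the backend's bytes-per-element.
static size_t cordBufferBytes(int outH, int outW, int bytesPerElement) {
    return (size_t)outH * outW * 2 * bytesPerElement;
}

int main() {
    printf("%zu\n", cordBufferBytes(64, 64, 4)); // fp32 backend: 32768 bytes
    printf("%zu\n", cordBufferBytes(64, 64, 2)); // fp16 backend: 16384 bytes
    return 0;
}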
@@ -103,49 +45,37 @@ ErrorCode CPUGridSample::onExecute(const std::vector<Tensor *> &inputs, const st
     auto inputTensor = inputs[0];
     auto gridTensor = inputs[1];
     auto outputTensor = outputs[0];
-
-    float *inputPtr = inputTensor->host<float>();
-    float *gridPtr = gridTensor->host<float>();
-    auto *outputPtr = outputTensor->host<float>();
+    auto core = static_cast<CPUBackend*>(backend())->functions();
+
+    auto inputPtr = inputTensor->host<uint8_t>();
+    auto gridPtr = gridTensor->host<uint8_t>();
+    auto outputPtr = outputTensor->host<uint8_t>();
 
     auto batches = inputTensor->buffer().dim[0].extent;
     auto channels = inputTensor->buffer().dim[1].extent;
-    auto channelC4 = UP_DIV(channels, 4);
+    auto channelCUnit = UP_DIV(channels, core->pack);
     auto inH = inputTensor->buffer().dim[2].extent;
     auto inW = inputTensor->buffer().dim[3].extent;
     auto outH = outputTensor->buffer().dim[2].extent;
     auto outW = outputTensor->buffer().dim[3].extent;
-    auto cordPtr = mTempCordBuffer->host<float>();
+    auto cordPtr = mTempCordBuffer->host<uint8_t>();
     auto threadCount = static_cast<CPUBackend*>(backend())->threadNumber();
-    auto tileCount = channelC4 * outH;
+    auto tileCount = channelCUnit * outH;
     for (auto b = 0; b < batches; ++b) {
-        const float *_inputPtr = inputPtr + b * inputTensor->buffer().dim[0].stride;
-        const float *_gridPtr = gridPtr + b * gridTensor->buffer().dim[0].stride;
-        float *_outputPtr = outputPtr + b * outputTensor->buffer().dim[0].stride;
+        auto _inputPtr = inputPtr + b * inputTensor->buffer().dim[0].stride * core->bytes;
+        auto _gridPtr = gridPtr + b * gridTensor->buffer().dim[0].stride * core->bytes;
+        auto _outputPtr = outputPtr + b * outputTensor->buffer().dim[0].stride * core->bytes;
         // Compute cord
-        for (auto h = 0; h < outH; ++h) {
-            auto __gridPtr = _gridPtr + h * gridTensor->buffer().dim[1].stride;
-            auto cordH = cordPtr + h * outW * 2;
-            for (auto w = 0; w < outW; ++w) {
-                auto x = getPosition(__gridPtr[2 * w + 0], inW, mAlignCorners);
-                auto y = getPosition(__gridPtr[2 * w + 1], inH, mAlignCorners);
-                cordH[2 * w + 0] = x;
-                cordH[2 * w + 1] = y;
-            }
-        }
+        core->MNNGridSampleComputeCord((float *)cordPtr, (const float *)_gridPtr, inH, inW, outH, outW, gridTensor->buffer().dim[1].stride, mAlignCorners);
         MNN_CONCURRENCY_BEGIN(tId, threadCount) {
             for (int index=tId; index < tileCount; index += threadCount) {
                 auto c = index / outH;
                 auto h = index % outH;
-                auto inpC = _inputPtr + c * inW * inH * 4;
-                auto outC = _outputPtr + c * outW * outH * 4;
-                auto cordH = cordPtr + h * outW * 2;
-                auto outH = outC + h * outW * 4;
-                for (auto w = 0; w < outW; ++w) {
-                    auto x = cordH[2 * w + 0];
-                    auto y = cordH[2 * w + 1];
-                    Vec4::save(outH + 4 * w, interpolate(y, x, inpC, inH, inW, mMode, mPaddingMode));
-                }
+                auto inputC = _inputPtr + c * inW * inH * core->pack * core->bytes;
+                auto outputC = _outputPtr + c * outW * outH * core->pack * core->bytes;
+                auto cordH = cordPtr + h * outW * 2 * core->bytes;
+                auto outputH = outputC + h * outW * core->pack * core->bytes;
+                core->MNNGridSampleInterp((float *)outputH, (const float *)inputC, (const float *)cordH, inH, inW, outW, (mMode == SampleMode_NEAREST), (mPaddingMode == BorderMode_ZEROS));
             }
         }
         MNN_CONCURRENCY_END();
@@ -1459,6 +1459,71 @@ void MNNPowC8(float* dest, const float* source, const float* powfParam, size_t b
 
 #endif // no MNN_USE_NEON
 
+void MNNGridSampleComputeCord(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners) {
+    float a = alignCorners ? 1.0f : 0.0f;
+    float b = alignCorners ? 0.0f : 1.0f;
+    for (auto h = 0; h < outH; ++h) {
+        auto __gridPtr = src + h * stride;
+        auto cordH = dst + h * outW * 2;
+        for (auto w = 0; w < outW; ++w) {
+            auto x = __gridPtr[2 * w + 0];
+            auto y = __gridPtr[2 * w + 1];
+            cordH[2 * w + 0] = ((1 + x) * (inW - a) - b) * 0.5f;
+            cordH[2 * w + 1] = ((1 + y) * (inH - a) - b) * 0.5f;
+        }
+    }
+}
+
+Vec4 MNNGridSampleLoadSample(int h, int w, const float *buffer, int height, int width, bool padMode) {
+    if (h < 0 || h >= height || w < 0 || w >= width) {
+        if(padMode == true) { //padMode == BorderMode_ZEROS
+            return 0.0f;
+        }
+        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
+        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
+        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
+        h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
+        w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
+    }
+
+    return Vec4::load(buffer + h * width * 4 + w * 4);
+}
+
+void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
+    for (auto ow = 0; ow < outW; ++ow) {
+        auto w = cordPtr[2 * ow + 0];
+        auto h = cordPtr[2 * ow + 1];
+        Vec4 interp;
+
+        if (sampleMode == true) { //sampleMode == SampleMode_NEAREST
+            int nh = ::floor(h + 0.5f);
+            int nw = ::floor(w + 0.5f);
+            interp = MNNGridSampleLoadSample(nh, nw, inputPtr, inH, inW, padMode);
+        } else { //sampleMode == GridSampleMode_BILINEAR
+            int w0_h = ::floor(h);
+            int w0_w = ::floor(w);
+            int w1_h = ::ceil(h);
+            int w1_w = ::ceil(w);
+            auto oneV = Vec4(1.0f);
+
+            Vec4 i00 = MNNGridSampleLoadSample(w0_h, w0_w, inputPtr, inH, inW, padMode);
+            Vec4 i01 = MNNGridSampleLoadSample(w0_h, w1_w, inputPtr, inH, inW, padMode);
+            Vec4 i10 = MNNGridSampleLoadSample(w1_h, w0_w, inputPtr, inH, inW, padMode);
+            Vec4 i11 = MNNGridSampleLoadSample(w1_h, w1_w, inputPtr, inH, inW, padMode);
+            auto f0 = Vec4((float)w1_w - w);
+            auto f1 = oneV - f0;
+            auto h0 = Vec4((float)w1_h - h);
+            auto h1 = oneV - h0;
+
+            Vec4 i0 = i00 * f0 + i01 * f1;
+            Vec4 i1 = i10 * f0 + i11 * f1;
+
+            interp = i0 * h0 + i1 * h1;
+        }
+
+        Vec4::save(outputPtr + 4 * ow, interp);
+    }
+}
+
 void MNNPackC4Uint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth) {
     int z, x;
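Taken together, the two generic kernels split grid sampling into a coordinate pass and an interpolation pass. A standalone single-channel walk-through (illustrative only, not MNN code): for a 2x2 input, grid point (0, 0) with alignCorners == false maps to pixel coordinate (0.5, 0.5), and bilinear interpolation returns the mean of the four pixels:

#include <cmath>
#include <cstdio>

// Single-channel scalar version of MNNGridSampleComputeCord +
// MNNGridSampleInterp above (bilinear mode, zero padding).
static float loadSample(int h, int w, const float* img, int height, int width) {
    if (h < 0 || h >= height || w < 0 || w >= width) {
        return 0.0f; // BorderMode_ZEROS
    }
    return img[h * width + w];
}

int main() {
    const int inH = 2, inW = 2;
    const float img[inH * inW] = {1.0f, 2.0f, 3.0f, 4.0f};

    // Coordinate pass for grid point (0, 0), alignCorners == false:
    // ((1 + 0) * 2 - 1) * 0.5 = 0.5 on both axes.
    float w = ((1 + 0.0f) * inW - 1) * 0.5f;
    float h = ((1 + 0.0f) * inH - 1) * 0.5f;

    // Interpolation pass, same weight construction as MNNGridSampleInterp.
    int w0_w = (int)::floor(w), w1_w = (int)::ceil(w);
    int w0_h = (int)::floor(h), w1_h = (int)::ceil(h);
    float f0 = (float)w1_w - w, f1 = 1.0f - f0;
    float h0 = (float)w1_h - h, h1 = 1.0f - h0;
    float i0 = loadSample(w0_h, w0_w, img, inH, inW) * f0
             + loadSample(w0_h, w1_w, img, inH, inW) * f1;
    float i1 = loadSample(w1_h, w0_w, img, inH, inW) * f0
             + loadSample(w1_h, w1_w, img, inH, inW) * f1;
    printf("%g\n", i0 * h0 + i1 * h1); // 2.5: the mean of the four pixels
    return 0;
}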
@@ -2515,6 +2580,8 @@ void MNNCoreFunctionInit() {
     gCoreFunction->MNNStrassenMergeCFunction = MNNStrassenMergeCFunction;
     gCoreFunction->penalty = 1.5f;
     gCoreFunction->MNNScaleAndAddBias = MNNScaleAndAddBias;
+    gCoreFunction->MNNGridSampleComputeCord = MNNGridSampleComputeCord;
+    gCoreFunction->MNNGridSampleInterp = MNNGridSampleInterp;
     gCoreFunction->MNNAddC4WithStride = MNNAddC4WithStride;
     gCoreFunction->MNNCopyC4WithStride = MNNCopyC4WithStride;
@@ -43,6 +43,9 @@ void MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const f
                         size_t biasNumber);
 void MNNScaleAndAddBiasScalar(float* dst, const float* src, float bias, float alpha, size_t number);
 
+void MNNGridSampleComputeCord(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
+void MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
+
 void MNNUnpackTranspose(float* dst, const float* src, size_t area, size_t depth);
 void MNNUnpackTransposeInt16(int16_t* dst, const int16_t* src, size_t area, size_t depth);
 void MNNUnpackTransposeUint8(uint8_t* dst, const uint8_t* src, size_t area, size_t depth);
@@ -191,6 +194,8 @@ struct CoreFunctions {
                           size_t bStride, size_t height);
     void(*MNNStrassenMergeCFunction)(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub);
     void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
+    void(*MNNGridSampleComputeCord)(float* dst, const float* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners);
+    void(*MNNGridSampleInterp)(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
     float penalty;
 
     void(*MNNCopyC4WithStride)(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
@@ -357,6 +357,7 @@ bool OpCommonUtils::opCompabilityForLowp(const Op* op) {
         case OpType_ReLU:
         case OpType_ReLU6:
         case OpType_PReLU:
+        case OpType_GridSample:
             return true;
         default:
             break;