mirror of https://github.com/alibaba/MNN.git
114 lines
3.8 KiB
C++
114 lines
3.8 KiB
C++
//
|
|
// Arm82Relu.cpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2020/2/13.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
#if defined(__ANDROID__) || defined(__aarch64__)
|
|
|
|
#include <limits>
|
|
|
|
#include "Arm82Relu.hpp"
|
|
#include "Arm82Backend.hpp"
|
|
#include "Arm82OptFunc.hpp"
|
|
#include "core/Concurrency.h"
|
|
#include "core/Macro.h"
|
|
#include "half.hpp"
|
|
#include <algorithm>
|
|
#include <arm_neon.h>
|
|
|
|
namespace MNN {
|
|
|
|
static void _MNNArm82PReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 *slope, size_t length) {
|
|
#ifdef MNN_USE_NEON
|
|
float16x8_t value_0 = vmovq_n_f16(0);
|
|
float16x8_t slopeV = vld1q_f16(slope);
|
|
#endif
|
|
|
|
for (int i = 0; i < length; ++i) {
|
|
#ifdef MNN_USE_NEON
|
|
float16x8_t value = vld1q_f16(src + i * ARMV82_CHANNEL_UNIT);
|
|
float16x8_t mulSlope = vmulq_f16(value, slopeV);
|
|
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
|
|
|
|
vst1q_f16(dst + i * ARMV82_CHANNEL_UNIT, vbslq_f16(lessThanZero, mulSlope, value));
|
|
#else
|
|
|
|
for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) {
|
|
if (src[i * ARMV82_CHANNEL_UNIT + j] < 0) {
|
|
dst[i * ARMV82_CHANNEL_UNIT + j] = src[i * ARMV82_CHANNEL_UNIT + j] * slope[j];
|
|
} else {
|
|
dst[i * ARMV82_CHANNEL_UNIT + j] = src[i * ARMV82_CHANNEL_UNIT + j];
|
|
}
|
|
}
|
|
|
|
#endif
|
|
}
|
|
}
|
|
|
|
static void _MNNArm82LeakyReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 slope, size_t length) {
|
|
float16x8_t value_0 = vmovq_n_f16(0);
|
|
float16x8_t slopeV = vmovq_n_f16(slope);
|
|
auto lC8 = length / ARMV82_CHANNEL_UNIT;
|
|
auto remain = length % ARMV82_CHANNEL_UNIT;
|
|
|
|
for (int i = 0; i < lC8; ++i) {
|
|
float16x8_t value = vld1q_f16(src);
|
|
float16x8_t mulSlope = vmulq_f16(value, slopeV);
|
|
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
|
|
vst1q_f16(dst, vbslq_f16(lessThanZero, mulSlope, value));
|
|
src += ARMV82_CHANNEL_UNIT;
|
|
dst += ARMV82_CHANNEL_UNIT;
|
|
}
|
|
if (remain > 0) {
|
|
float16_t tempSrc[ARMV82_CHANNEL_UNIT];
|
|
float16_t tempDst[ARMV82_CHANNEL_UNIT];
|
|
::memcpy(tempSrc, src, remain * sizeof(int16_t));
|
|
float16x8_t value = vld1q_f16(tempSrc);
|
|
float16x8_t mulSlope = vmulq_f16(value, slopeV);
|
|
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
|
|
vst1q_f16(tempDst, vbslq_f16(lessThanZero, mulSlope, value));
|
|
::memcpy(dst, tempDst, remain * sizeof(int16_t));
|
|
}
|
|
}
|
|
|
|
static void _MNNArm82ReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, size_t length) {
|
|
float16x8_t value_0 = vmovq_n_f16(0);
|
|
auto lC8 = length / ARMV82_CHANNEL_UNIT;
|
|
auto remain = length % ARMV82_CHANNEL_UNIT;
|
|
for (int i = 0; i < lC8; ++i) {
|
|
float16x8_t value = vld1q_f16(src);
|
|
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
|
|
|
|
vst1q_f16(dst, vbslq_f16(lessThanZero, value_0, value));
|
|
dst += ARMV82_CHANNEL_UNIT;
|
|
src += ARMV82_CHANNEL_UNIT;
|
|
}
|
|
if (remain > 0) {
|
|
float16_t tempSrc[ARMV82_CHANNEL_UNIT];
|
|
float16_t tempDst[ARMV82_CHANNEL_UNIT];
|
|
::memcpy(tempSrc, src, remain * sizeof(int16_t));
|
|
float16x8_t value = vld1q_f16(tempSrc);
|
|
uint16x8_t lessThanZero = vcleq_f16(value, value_0);
|
|
vst1q_f16(tempDst, vbslq_f16(lessThanZero, value_0, value));
|
|
::memcpy(dst, tempDst, remain * sizeof(int16_t));
|
|
}
|
|
}
|
|
|
|
void Arm82Relu::reluWithSlopeChannel(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad) {
|
|
auto dst = (FLOAT16*)dstO;
|
|
auto src = (const FLOAT16*)srcO;
|
|
auto slope = (const FLOAT16*)slopeO;
|
|
for (int z=0; z<depthQuad; ++z) {
|
|
auto dstZ = dst + z * 8 * sizeQuad;
|
|
auto srcZ = src + z * 8 * sizeQuad;
|
|
auto slopeZ = slope + 8 * z;
|
|
_MNNArm82PReluWithChannel(dstZ, srcZ, slopeZ, sizeQuad);
|
|
}
|
|
}
|
|
|
|
} // namespace MNN
|
|
|
|
#endif
|