| 
									
										
										
										
											2020-02-13 19:09:39 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  Arm82Relu.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | //  Created by MNN on 2020/2/13.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							| 
									
										
										
										
											2020-02-13 19:09:39 +08:00
										 |  |  | //
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | #if defined(__ANDROID__) || defined(__aarch64__)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-14 11:14:56 +08:00
										 |  |  | #include <limits>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | #include "Arm82Relu.hpp"
 | 
					
						
							|  |  |  | #include "Arm82Backend.hpp"
 | 
					
						
							|  |  |  | #include "Arm82OptFunc.hpp"
 | 
					
						
							| 
									
										
										
										
											2020-06-02 20:21:12 +08:00
										 |  |  | #include "core/Concurrency.h"
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include "core/Macro.h"
 | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  | #include "half.hpp"
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | #include <algorithm>
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include <arm_neon.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-14 11:14:56 +08:00
										 |  |  | static void _MNNArm82PReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 *slope, size_t length) { | 
					
						
							| 
									
										
										
										
											2020-06-02 20:21:12 +08:00
										 |  |  | #ifdef MNN_USE_NEON
 | 
					
						
							|  |  |  |     float16x8_t value_0 = vmovq_n_f16(0); | 
					
						
							|  |  |  |     float16x8_t slopeV  = vld1q_f16(slope); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int i = 0; i < length; ++i) { | 
					
						
							|  |  |  | #ifdef MNN_USE_NEON
 | 
					
						
							|  |  |  |         float16x8_t value        = vld1q_f16(src + i * ARMV82_CHANNEL_UNIT); | 
					
						
							|  |  |  |         float16x8_t mulSlope     = vmulq_f16(value, slopeV); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         uint16x8_t lessThanZero = vcleq_f16(value, value_0); | 
					
						
							| 
									
										
										
										
											2020-06-02 20:21:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         vst1q_f16(dst + i * ARMV82_CHANNEL_UNIT, vbslq_f16(lessThanZero, mulSlope, value)); | 
					
						
							|  |  |  | #else
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int j = 0; j < ARMV82_CHANNEL_UNIT; ++j) { | 
					
						
							|  |  |  |             if (src[i * ARMV82_CHANNEL_UNIT + j] < 0) { | 
					
						
							|  |  |  |                 dst[i * ARMV82_CHANNEL_UNIT + j] = src[i * ARMV82_CHANNEL_UNIT + j] * slope[j]; | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 dst[i * ARMV82_CHANNEL_UNIT + j] = src[i * ARMV82_CHANNEL_UNIT + j]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  | static void _MNNArm82LeakyReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, const FLOAT16 slope, size_t length) { | 
					
						
							|  |  |  |     float16x8_t value_0 = vmovq_n_f16(0); | 
					
						
							|  |  |  |     float16x8_t slopeV  = vmovq_n_f16(slope); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     auto lC8 = length / ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |     auto remain = length % ARMV82_CHANNEL_UNIT; | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     for (int i = 0; i < lC8; ++i) { | 
					
						
							|  |  |  |         float16x8_t value        = vld1q_f16(src); | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  |         float16x8_t mulSlope     = vmulq_f16(value, slopeV); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         uint16x8_t lessThanZero = vcleq_f16(value, value_0); | 
					
						
							|  |  |  |         vst1q_f16(dst, vbslq_f16(lessThanZero, mulSlope, value)); | 
					
						
							|  |  |  |         src += ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |         dst += ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (remain > 0) { | 
					
						
							|  |  |  |         float16_t tempSrc[ARMV82_CHANNEL_UNIT]; | 
					
						
							|  |  |  |         float16_t tempDst[ARMV82_CHANNEL_UNIT]; | 
					
						
							|  |  |  |         ::memcpy(tempSrc, src, remain * sizeof(int16_t)); | 
					
						
							|  |  |  |         float16x8_t value        = vld1q_f16(tempSrc); | 
					
						
							|  |  |  |         float16x8_t mulSlope     = vmulq_f16(value, slopeV); | 
					
						
							|  |  |  |         uint16x8_t lessThanZero = vcleq_f16(value, value_0); | 
					
						
							|  |  |  |         vst1q_f16(tempDst, vbslq_f16(lessThanZero, mulSlope, value)); | 
					
						
							|  |  |  |         ::memcpy(dst, tempDst, remain * sizeof(int16_t)); | 
					
						
							| 
									
										
										
										
											2021-01-14 11:14:56 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-14 11:14:56 +08:00
										 |  |  | static void _MNNArm82ReluWithChannel(FLOAT16 *dst, const FLOAT16 *src, size_t length) { | 
					
						
							|  |  |  |     float16x8_t value_0 = vmovq_n_f16(0); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     auto lC8 = length / ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |     auto remain = length % ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |     for (int i = 0; i < lC8; ++i) { | 
					
						
							|  |  |  |         float16x8_t value        = vld1q_f16(src); | 
					
						
							|  |  |  |         uint16x8_t lessThanZero = vcleq_f16(value, value_0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         vst1q_f16(dst, vbslq_f16(lessThanZero, value_0, value)); | 
					
						
							|  |  |  |         dst += ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |         src += ARMV82_CHANNEL_UNIT; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (remain > 0) { | 
					
						
							|  |  |  |         float16_t tempSrc[ARMV82_CHANNEL_UNIT]; | 
					
						
							|  |  |  |         float16_t tempDst[ARMV82_CHANNEL_UNIT]; | 
					
						
							|  |  |  |         ::memcpy(tempSrc, src, remain * sizeof(int16_t)); | 
					
						
							|  |  |  |         float16x8_t value        = vld1q_f16(tempSrc); | 
					
						
							|  |  |  |         uint16x8_t lessThanZero = vcleq_f16(value, value_0); | 
					
						
							|  |  |  |         vst1q_f16(tempDst, vbslq_f16(lessThanZero, value_0, value)); | 
					
						
							|  |  |  |         ::memcpy(dst, tempDst, remain * sizeof(int16_t)); | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | void Arm82Relu::reluWithSlopeChannel(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad) { | 
					
						
							|  |  |  |     auto dst = (FLOAT16*)dstO; | 
					
						
							|  |  |  |     auto src = (const FLOAT16*)srcO; | 
					
						
							|  |  |  |     auto slope = (const FLOAT16*)slopeO; | 
					
						
							|  |  |  |     for (int z=0; z<depthQuad; ++z) { | 
					
						
							|  |  |  |         auto dstZ = dst + z * 8 * sizeQuad; | 
					
						
							|  |  |  |         auto srcZ = src + z * 8 * sizeQuad; | 
					
						
							|  |  |  |         auto slopeZ = slope + 8 * z; | 
					
						
							|  |  |  |         _MNNArm82PReluWithChannel(dstZ, srcZ, slopeZ, sizeQuad); | 
					
						
							| 
									
										
										
										
											2020-09-16 12:27:58 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-06-02 20:21:12 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | } // namespace MNN
 | 
					
						
							| 
									
										
										
										
											2020-05-06 11:12:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-02 20:21:12 +08:00
										 |  |  | #endif
 |