mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			286 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			286 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  Arm82Binary.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2021/01/05.
 | |
| //  Copyright © 2021, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #if defined(__ANDROID__) || defined(__aarch64__)
 | |
| #include <algorithm>
 | |
| #include "backend/arm82/Arm82Binary.hpp"
 | |
| #include "backend/arm82/Arm82Backend.hpp"
 | |
| #include "backend/cpu/BinaryUtils.hpp"
 | |
| #include "core/Macro.h"
 | |
| 
 | |
| #include <arm_neon.h>
 | |
| 
 | |
| namespace MNN {
 | |
| template<typename Func>
 | |
| void Arm82BinaryWrap(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
 | |
|     auto dst = (FLOAT16*)dstRaw;
 | |
|     auto src0 = (const FLOAT16*)src0Raw;
 | |
|     auto src1 = (const FLOAT16*)src1Raw;
 | |
|     Func compute;
 | |
|     const int sizeDivUnit = elementSize / 4;
 | |
|     const int remainCount = elementSize - sizeDivUnit * 4;
 | |
| 
 | |
|     float A[4];
 | |
|     float B[4];
 | |
|     float C[4];
 | |
|     if (-1 == needBroadcastIndex) {
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src0Ptr = src0;
 | |
|                 const auto src1Ptr = src1;
 | |
|                 auto dstPtr = dst;
 | |
|                 vst1q_f32(A, vcvt_f32_f16(vld1_f16(src0Ptr)));
 | |
|                 vst1q_f32(B, vcvt_f32_f16(vld1_f16(src1Ptr)));
 | |
|                 for (int v = 0; v < 4; ++ v) {
 | |
|                     C[v] = compute(A[v], B[v]);
 | |
|                 }
 | |
|                 vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
 | |
|                 src0 += 4;
 | |
|                 src1 += 4;
 | |
|                 dst += 4;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc0[4];
 | |
|             FLOAT16 tempSrc1[4];
 | |
|             FLOAT16 tempDst[4];
 | |
|             ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
 | |
|             ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
 | |
|             vst1q_f32(A, vcvt_f32_f16(vld1_f16(tempSrc0)));
 | |
|             vst1q_f32(B, vcvt_f32_f16(vld1_f16(tempSrc1)));
 | |
|             for (int v = 0; v < remainCount; ++ v) {
 | |
|                 C[v] = compute(A[v], B[v]);
 | |
|             }
 | |
|             vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     } else if (0 == needBroadcastIndex) {
 | |
|         const FLOAT16 srcValue0 = src0[0];
 | |
|         float16x4_t a = vmov_n_f16(srcValue0);
 | |
|         vst1q_f32(A, vcvt_f32_f16(a));
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src1Ptr = src1;
 | |
|                 auto dstPtr = dst;
 | |
|                 vst1q_f32(B, vcvt_f32_f16(vld1_f16(src1Ptr)));
 | |
|                 for (int v = 0; v < 4; ++ v) {
 | |
|                     C[v] = compute(A[v], B[v]);
 | |
|                 }
 | |
|                 vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
 | |
|                 src1 += 4;
 | |
|                 dst += 4;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc1[4];
 | |
|             FLOAT16 tempDst[4];
 | |
|             ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
 | |
|             vst1q_f32(B, vcvt_f32_f16(vld1_f16(tempSrc1)));
 | |
|             for (int v = 0; v < remainCount; ++ v) {
 | |
|                 C[v] = compute(A[v], B[v]);
 | |
|             }
 | |
|             vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     } else {
 | |
|         const FLOAT16 srcValue1 = src1[0];
 | |
|         float16x4_t b = vmov_n_f16(srcValue1);
 | |
|         vst1q_f32(B, vcvt_f32_f16(b));
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src0Ptr = src0;
 | |
|                 auto dstPtr = dst;
 | |
|                 vst1q_f32(A, vcvt_f32_f16(vld1_f16(src0Ptr)));
 | |
|                 for (int v = 0; v < 4; ++ v) {
 | |
|                     C[v] = compute(A[v], B[v]);
 | |
|                 }
 | |
|                 vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
 | |
|                 src0 += 4;
 | |
|                 dst += 4;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc0[4];
 | |
|             FLOAT16 tempDst[4];
 | |
|             ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
 | |
|             vst1q_f32(A, vcvt_f32_f16(vld1_f16(tempSrc0)));
 | |
|             for (int v = 0; v < remainCount; ++ v) {
 | |
|                 C[v] = compute(A[v], B[v]);
 | |
|             }
 | |
|             vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| 
 | |
| template<typename Func>
 | |
| void Arm82Binary(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
 | |
|     auto dst = (FLOAT16*)dstRaw;
 | |
|     auto src0 = (FLOAT16*)src0Raw;
 | |
|     auto src1 = (FLOAT16*)src1Raw;
 | |
|     Func compute;
 | |
|     const int sizeDivUnit = elementSize / ARMV82_CHANNEL_UNIT;
 | |
|     const int remainCount = elementSize - sizeDivUnit * ARMV82_CHANNEL_UNIT;
 | |
| 
 | |
|     if (-1 == needBroadcastIndex) {
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src0Ptr = src0;
 | |
|                 const auto src1Ptr = src1;
 | |
|                 auto dstPtr = dst;
 | |
|                 float16x8_t a = vld1q_f16(src0Ptr);
 | |
|                 float16x8_t b = vld1q_f16(src1Ptr);
 | |
|                 vst1q_f16(dstPtr, compute(a, b));
 | |
|                 src0 += 8;
 | |
|                 src1 += 8;
 | |
|                 dst += 8;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc0[8];
 | |
|             FLOAT16 tempSrc1[8];
 | |
|             FLOAT16 tempDst[8];
 | |
|             ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
 | |
|             ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
 | |
|             float16x8_t a = vld1q_f16(tempSrc0);
 | |
|             float16x8_t b = vld1q_f16(tempSrc1);
 | |
|             vst1q_f16(tempDst, compute(a, b));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     } else if (0 == needBroadcastIndex) {
 | |
|         const FLOAT16 srcValue0 = src0[0];
 | |
|         float16x8_t a = vmovq_n_f16(srcValue0);
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src1Ptr = src1;
 | |
|                 auto dstPtr = dst;
 | |
|                 float16x8_t b = vld1q_f16(src1Ptr);
 | |
|                 vst1q_f16(dstPtr, compute(a, b));
 | |
|                 src1 += 8;
 | |
|                 dst += 8;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc1[8];
 | |
|             FLOAT16 tempDst[8];
 | |
|             ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
 | |
|             float16x8_t b = vld1q_f16(tempSrc1);
 | |
|             vst1q_f16(tempDst, compute(a, b));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     } else {
 | |
|         const FLOAT16 srcValue1 = src1[0];
 | |
|         float16x8_t b = vmovq_n_f16(srcValue1);
 | |
|         if (sizeDivUnit > 0) {
 | |
|             for (int i = 0; i < sizeDivUnit; ++i) {
 | |
|                 const auto src0Ptr = src0;
 | |
|                 auto dstPtr = dst;
 | |
|                 float16x8_t a = vld1q_f16(src0Ptr);
 | |
|                 vst1q_f16(dstPtr, compute(a, b));
 | |
|                 src0 += 8;
 | |
|                 dst += 8;
 | |
|             }
 | |
|         }
 | |
|         if (remainCount > 0) {
 | |
|             FLOAT16 tempSrc0[8];
 | |
|             FLOAT16 tempDst[8];
 | |
|             ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
 | |
|             float16x8_t a = vld1q_f16(tempSrc0);
 | |
|             vst1q_f16(tempDst, compute(a, b));
 | |
|             ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| 
 | |
| struct VecBinaryAdd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vaddq_f16(x, y);
 | |
|     }
 | |
| };
 | |
| 
 | |
| struct VecBinarySub : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vsubq_f16(x, y);
 | |
|     }
 | |
| };
 | |
| 
 | |
| struct VecBinaryMul : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vmulq_f16(x, y);
 | |
|     }
 | |
| };
 | |
| 
 | |
| struct VecBinaryMin : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vminq_f16(x, y);
 | |
|     }
 | |
| };
 | |
| 
 | |
| struct VecBinaryMax : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vmaxq_f16(x, y);
 | |
|     }
 | |
| };
 | |
| 
 | |
| struct VecBinarySqd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
 | |
|     float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
 | |
|         return vmulq_f16(vsubq_f16(x, y), vsubq_f16(x, y));
 | |
|     }
 | |
| };
 | |
| 
 | |
| 
 | |
| MNNBinaryExecute Arm82BinaryFloat::select(int32_t type) {
 | |
|     switch (type) {
 | |
|         case BinaryOpOperation_ADD:
 | |
|             return Arm82Binary<VecBinaryAdd>;
 | |
|             break;
 | |
|         case BinaryOpOperation_SUB:
 | |
|             return Arm82Binary<VecBinarySub>;
 | |
|             break;
 | |
|         case BinaryOpOperation_MUL:
 | |
|             return Arm82Binary<VecBinaryMul>;
 | |
|             break;
 | |
|         case BinaryOpOperation_MINIMUM:
 | |
|             return Arm82Binary<VecBinaryMin>;
 | |
|             break;
 | |
|         case BinaryOpOperation_MAXIMUM:
 | |
|             return Arm82Binary<VecBinaryMax>;
 | |
|             break;
 | |
|         case BinaryOpOperation_SquaredDifference:
 | |
|             return Arm82Binary<VecBinarySqd>;
 | |
|             break;
 | |
|         case BinaryOpOperation_REALDIV:
 | |
|             return Arm82BinaryWrap<BinaryRealDiv<float, float, float>>;
 | |
|             break;
 | |
|         case BinaryOpOperation_FLOORDIV:
 | |
|             return Arm82BinaryWrap<BinaryFloorDiv<float, float, float>>;
 | |
|             break;
 | |
|         case BinaryOpOperation_FLOORMOD:
 | |
|             return Arm82BinaryWrap<BinaryFloorMod<float, float, float>>;
 | |
|             break;
 | |
|         case BinaryOpOperation_POW:
 | |
|             return Arm82BinaryWrap<BinaryPow<float, float, float>>;
 | |
|             break;
 | |
|         case BinaryOpOperation_ATAN2:
 | |
|             return Arm82BinaryWrap<BinaryAtan2<float, float, float>>;
 | |
|             break;
 | |
|         case BinaryOpOperation_MOD:
 | |
|             return Arm82BinaryWrap<BinaryMod<float, float, float>>;
 | |
|             break;
 | |
|         default:
 | |
|             return nullptr;
 | |
|             break;
 | |
|     }
 | |
|     return nullptr;
 | |
| }
 | |
| 
 | |
| } // namespace MNN
 | |
| #endif
 |