diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index 431b5fd6..a16a7b71 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -63,6 +63,8 @@ struct QuanPostTreatParameters { const int32_t* bias; int32_t maxValue; int32_t minValue; + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; }; void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); diff --git a/source/backend/cpu/x86_x64/avx/GemmCommon.cpp b/source/backend/cpu/x86_x64/avx/GemmCommon.cpp index b250c46e..ace17813 100644 --- a/source/backend/cpu/x86_x64/avx/GemmCommon.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmCommon.cpp @@ -203,8 +203,19 @@ void AVX2GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const fl } } } - +#ifdef MNN_X86_USE_ASM +extern "C" { +void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post); +} +#endif void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post) { +#ifdef MNN_X86_USE_ASM + size_t strides[3]; + strides[0] = src_depth_quad; + strides[1] = dst_step; + strides[2] = dst_depth_quad; + _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(dst, src, weight, strides, post); +#else const auto dst_step_tmp = dst_step / sizeof(int8_t); __m128 zero128 = _mm_set1_ps(0.0f); __m128 minValue = _mm_set1_ps(post->minValue); @@ -356,4 +367,5 @@ auto d##i = _mm_add_epi32(d##i##0, d##i##1); d0 = _mm_packs_epi16(d0, d2); _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0)); } +#endif } diff --git a/source/backend/cpu/x86_x64/avx/_AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S b/source/backend/cpu/x86_x64/avx/_AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S new file mode 100644 index 00000000..bbda7d28 --- /dev/null +++ b/source/backend/cpu/x86_x64/avx/_AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S @@ -0,0 +1,309 @@ +// +// _AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S +// MNN +// +// Created by MNN on 2020/11/04. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#include "../MNNAsmGlobal.h" +.text +.align 4 + +//struct QuanPostTreatParameters { +// const float* scale; +// const int32_t* bias; +// int32_t maxValue; +// int32_t minValue; +// float roundValuePos = 0.5f; +// float roundValueNeg = -0.5f; +//}; + +asm_function _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain +//void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post); + + +// SystemV Auto: rdi: dst, rsi:src, rdx:weight, rcx:strides, r8: post +// Microsoft x64 Auto: rcx:C, rdx:A, r8:B, r9:parameter +pushq %rbp +movq %rsp, %rbp + +#ifdef WIN32 +movq 48(%rsp), %r10 +pushq %rdi +pushq %rsi +pushq %r12 +pushq %r13 +movq %rcx, %rdi +movq %rdx, %rsi +movq %r8, %rdx +movq %r9, %rcx +movq %r10, %r9 +pushq %r14 +pushq %r15 +#else +pushq %r12 +pushq %r13 +pushq %r14 +pushq %r15 +movq %r8, %r9 +#endif + +movq 8(%rcx), %r10 // dst_step +movq 16(%rcx), %r8 // dst_depth_quad +movq (%rcx), %rcx // src_depth_quad +movq (%r9), %r12 // scale +movq 8(%r9), %r15 // bias + + +// ymm0-ymm1: Src +// ymm2-ymm3: Weight +// ymm4-ymm7: TmpDst +// ymm8-ymm15: Dst Sum + +// Last dst save to ymm8-ymm11 + +cmpq $0, %r8 +je End + +movq %rsi, %r13 +subq $64, %rsp +LoopDz: + movq %rcx, %r11 + movq %r13, %rsi + movq %rdx, %r14 + subq $1, %r11 + vpmovsxbw (%rsi), %ymm0 + vpmovsxbw 16(%rsi), %ymm1 + vpmovsxbw (%rdx), %ymm2 + vpmovsxbw 16(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm8 + vpmaddwd %ymm0, %ymm3, %ymm9 + vpmaddwd %ymm1, %ymm2, %ymm12 + vpmaddwd %ymm1, %ymm3, %ymm13 + vpmovsxbw 32(%rdx), %ymm2 + vpmovsxbw 48(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm10 + vpmaddwd %ymm0, %ymm3, %ymm11 + vpmaddwd %ymm1, %ymm2, %ymm14 + vpmaddwd %ymm1, %ymm3, %ymm15 + addq $64, %rdx + addq $64, %rsi + + testq %r11, %r11 + je FirstLoopSzEnd + + FirstLoopSz: + vpmovsxbw (%rsi), %ymm0 + vpmovsxbw 16(%rsi), %ymm1 + vpmovsxbw (%rdx), %ymm2 + vpmovsxbw 16(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm4 + vpmaddwd %ymm0, %ymm3, %ymm5 + vpmaddwd %ymm1, %ymm2, %ymm6 + vpmaddwd %ymm1, %ymm3, %ymm7 + vpaddd %ymm4, %ymm8, %ymm8 + vpaddd %ymm5, %ymm9, %ymm9 + vpmovsxbw 32(%rdx), %ymm2 + vpmovsxbw 48(%rdx), %ymm3 + vpaddd %ymm6, %ymm12, %ymm12 + vpaddd %ymm7, %ymm13, %ymm13 + + + vpmaddwd %ymm0, %ymm2, %ymm4 + vpmaddwd %ymm0, %ymm3, %ymm5 + vpmaddwd %ymm1, %ymm2, %ymm6 + vpmaddwd %ymm1, %ymm3, %ymm7 + vpaddd %ymm4, %ymm10, %ymm10 + vpaddd %ymm5, %ymm11, %ymm11 + vpaddd %ymm6, %ymm14, %ymm14 + vpaddd %ymm7, %ymm15, %ymm15 + + addq $64, %rdx + addq $64, %rsi + + subq $1, %r11 + testq %r11, %r11 + jne FirstLoopSz + + FirstLoopSzEnd: + + vphaddd %ymm9, %ymm8, %ymm8 + vphaddd %ymm11, %ymm10, %ymm10 + vphaddd %ymm13, %ymm12, %ymm12 + vphaddd %ymm15, %ymm14, %ymm14 + + vphaddd %ymm10, %ymm8, %ymm8 + vphaddd %ymm14, %ymm12, %ymm9 + + vmovups %ymm8, (%rsp) + vmovups %ymm9, 32(%rsp) + + movq %rcx, %r11 + movq %r13, %rsi + movq %r14, %rdx + vpmovsxbw 32(%rsi), %ymm0 + vpmovsxbw 48(%rsi), %ymm1 + vpmovsxbw (%rdx), %ymm2 + vpmovsxbw 16(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm8 + vpmaddwd %ymm0, %ymm3, %ymm9 + vpmaddwd %ymm1, %ymm2, %ymm12 + vpmaddwd %ymm1, %ymm3, %ymm13 + + vpmovsxbw 32(%rdx), %ymm2 + vpmovsxbw 48(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm10 + vpmaddwd %ymm0, %ymm3, %ymm11 + vpmaddwd %ymm1, %ymm2, %ymm14 + vpmaddwd %ymm1, %ymm3, %ymm15 + + addq $64, %rdx + addq $64, %rsi + + subq $1, %r11 + testq %r11, %r11 + je SecondLoopSzEnd + + SecondLoopSz: + vpmovsxbw 32(%rsi), %ymm0 + vpmovsxbw 48(%rsi), %ymm1 + vpmovsxbw (%rdx), %ymm2 + vpmovsxbw 16(%rdx), %ymm3 + + vpmaddwd %ymm0, %ymm2, %ymm4 + vpmaddwd %ymm0, %ymm3, %ymm5 + vpmaddwd %ymm1, %ymm2, %ymm6 + vpmaddwd %ymm1, %ymm3, %ymm7 + vpaddd %ymm4, %ymm8, %ymm8 + vpaddd %ymm5, %ymm9, %ymm9 + vpmovsxbw 32(%rdx), %ymm2 + vpmovsxbw 48(%rdx), %ymm3 + vpaddd %ymm6, %ymm12, %ymm12 + vpaddd %ymm7, %ymm13, %ymm13 + + + vpmaddwd %ymm0, %ymm2, %ymm4 + vpmaddwd %ymm0, %ymm3, %ymm5 + vpmaddwd %ymm1, %ymm2, %ymm6 + vpmaddwd %ymm1, %ymm3, %ymm7 + vpaddd %ymm4, %ymm10, %ymm10 + vpaddd %ymm5, %ymm11, %ymm11 + vpaddd %ymm6, %ymm14, %ymm14 + vpaddd %ymm7, %ymm15, %ymm15 + + addq $64, %rdx + addq $64, %rsi + + subq $1, %r11 + testq %r11, %r11 + jne SecondLoopSz + SecondLoopSzEnd: + + vphaddd %ymm9, %ymm8, %ymm8 + vphaddd %ymm11, %ymm10, %ymm10 + vphaddd %ymm13, %ymm12, %ymm12 + vphaddd %ymm15, %ymm14, %ymm14 + + vphaddd %ymm10, %ymm8, %ymm10 + vphaddd %ymm14, %ymm12, %ymm11 + + vmovups (%rsp), %ymm8 + vmovups 32(%rsp), %ymm9 + + Last: +.macro TRANSPOSE x0, x1, x2, x3 + // 32 = 0 + 16 * 2: frist 128 x0_lo, second 128 x1_lo + // 49 = 1 + 16 * 3: frist 128 x0_hi, second 128 x1_hi + vperm2f128 $32, \x1, \x0, \x2 + vperm2f128 $49, \x1, \x0, \x3 +.endm + TRANSPOSE %ymm8, %ymm10, %ymm0, %ymm1 + TRANSPOSE %ymm9, %ymm11, %ymm2, %ymm3 + + vpaddd %ymm0, %ymm1, %ymm0 + vpaddd %ymm2, %ymm3, %ymm2 + + vbroadcastf128 (%r12), %ymm8 + vbroadcastf128 (%r15), %ymm9 + + vpaddd %ymm9, %ymm0, %ymm0 + vpaddd %ymm9, %ymm2, %ymm2 + + vcvtdq2ps %ymm0, %ymm0 + vcvtdq2ps %ymm2, %ymm2 + + vmulps %ymm8, %ymm0, %ymm0 + vmulps %ymm8, %ymm2, %ymm2 + // zero + vxorps %ymm13, %ymm13, %ymm13 + + vbroadcastss 24(%r9), %ymm14 + vbroadcastss 28(%r9), %ymm15 + vbroadcastss 16(%r9), %ymm10 + vbroadcastss 20(%r9), %ymm11 + + // Round + vcmpltps %ymm13, %ymm0, %ymm4 + vcmpltps %ymm13, %ymm2, %ymm5 + + vblendvps %ymm4, %ymm15, %ymm14, %ymm4 + vblendvps %ymm5, %ymm15, %ymm14, %ymm5 + + vaddps %ymm0, %ymm4, %ymm0 + vaddps %ymm2, %ymm5, %ymm2 + + // 3: ROUND to Zero + vroundps $3, %ymm0, %ymm0 + vroundps $3, %ymm2, %ymm2 + vcvtps2dq %ymm0, %ymm0 + vcvtps2dq %ymm2, %ymm2 + + vpminsd %ymm10, %ymm0, %ymm0 + vpminsd %ymm10, %ymm2, %ymm2 + + vpmaxsd %ymm11, %ymm0, %ymm0 + vpmaxsd %ymm11, %ymm2, %ymm2 + + vpackssdw %ymm2, %ymm0, %ymm0 + vperm2f128 $1, %ymm0, %ymm0, %ymm1 + vpacksswb %ymm1, %ymm0, %ymm0 + + addq $16, %r12 + addq $16, %r15 + + vmovups %xmm0, (%rdi) + addq %r10, %rdi + + subq $1, %r8 + testq %r8, %r8 + jne LoopDz +addq $64, %rsp + +End: + +#ifdef WIN32 +popq %r15 +popq %r14 +popq %r13 +popq %r12 +popq %rsi +popq %rdi +popq %rbp +#else +popq %r15 +popq %r14 +popq %r13 +popq %r12 +popq %rbp +#endif + +// FIXME: if don't vzeroall, it will cause other op slow +vzeroall +retq +