mirror of https://github.com/alibaba/MNN.git
				
				
				
			[PATCH 25/78] [MNN:Speed] Add asm for avx int8
This commit is contained in:
		
							parent
							
								
									70016820a0
								
							
						
					
					
						commit
						1bd8d27131
					
				|  | @ -63,6 +63,8 @@ struct QuanPostTreatParameters { | ||||||
|     const int32_t* bias; |     const int32_t* bias; | ||||||
|     int32_t maxValue; |     int32_t maxValue; | ||||||
|     int32_t minValue; |     int32_t minValue; | ||||||
|  |     float roundValuePos = 0.5f; | ||||||
|  |     float roundValueNeg = -0.5f; | ||||||
| }; | }; | ||||||
| void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | ||||||
| void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | ||||||
|  |  | ||||||
|  | @ -203,8 +203,19 @@ void AVX2GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const fl | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | #ifdef MNN_X86_USE_ASM | ||||||
|  | extern "C" { | ||||||
|  | void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post); | ||||||
|  | } | ||||||
|  | #endif | ||||||
| void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post) { | void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post) { | ||||||
|  | #ifdef MNN_X86_USE_ASM | ||||||
|  |     size_t strides[3]; | ||||||
|  |     strides[0] = src_depth_quad; | ||||||
|  |     strides[1] = dst_step; | ||||||
|  |     strides[2] = dst_depth_quad; | ||||||
|  |     _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(dst, src, weight, strides, post); | ||||||
|  | #else | ||||||
|     const auto dst_step_tmp = dst_step / sizeof(int8_t); |     const auto dst_step_tmp = dst_step / sizeof(int8_t); | ||||||
|     __m128 zero128 = _mm_set1_ps(0.0f); |     __m128 zero128 = _mm_set1_ps(0.0f); | ||||||
|     __m128 minValue = _mm_set1_ps(post->minValue); |     __m128 minValue = _mm_set1_ps(post->minValue); | ||||||
|  | @ -356,4 +367,5 @@ auto d##i = _mm_add_epi32(d##i##0, d##i##1); | ||||||
|         d0 = _mm_packs_epi16(d0, d2); |         d0 = _mm_packs_epi16(d0, d2); | ||||||
|         _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0)); |         _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0)); | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,309 @@ | ||||||
|  | // | ||||||
|  | //  _AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S | ||||||
|  | //  MNN | ||||||
|  | // | ||||||
|  | //  Created by MNN on 2020/11/04. | ||||||
|  | //  Copyright © 2018, Alibaba Group Holding Limited | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include "../MNNAsmGlobal.h" | ||||||
|  | .text | ||||||
|  | .align 4
 | ||||||
|  | 
 | ||||||
|  | //struct QuanPostTreatParameters { | ||||||
|  | //    const float* scale;
 | ||||||
|  | //    const int32_t* bias;
 | ||||||
|  | //    int32_t maxValue;
 | ||||||
|  | //    int32_t minValue;
 | ||||||
|  | //    float roundValuePos = 0.5f;
 | ||||||
|  | //    float roundValueNeg = -0.5f;
 | ||||||
|  | //};
 | ||||||
|  | 
 | ||||||
|  | asm_function _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain | ||||||
|  | //void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post);
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | // SystemV Auto: rdi: dst, rsi:src, rdx:weight, rcx:strides, r8: post | ||||||
|  | // Microsoft x64 Auto: rcx:C, rdx:A, r8:B, r9:parameter | ||||||
|  | pushq   %rbp | ||||||
|  | movq    %rsp, %rbp | ||||||
|  | 
 | ||||||
|  | #ifdef WIN32 | ||||||
|  | movq 48(%rsp), %r10 | ||||||
|  | pushq %rdi | ||||||
|  | pushq %rsi | ||||||
|  | pushq %r12 | ||||||
|  | pushq %r13 | ||||||
|  | movq %rcx, %rdi | ||||||
|  | movq %rdx, %rsi | ||||||
|  | movq %r8, %rdx | ||||||
|  | movq %r9, %rcx | ||||||
|  | movq %r10, %r9 | ||||||
|  | pushq   %r14 | ||||||
|  | pushq   %r15 | ||||||
|  | #else | ||||||
|  | pushq   %r12 | ||||||
|  | pushq   %r13 | ||||||
|  | pushq   %r14 | ||||||
|  | pushq   %r15 | ||||||
|  | movq %r8, %r9 | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | movq 8(%rcx), %r10 // dst_step | ||||||
|  | movq 16(%rcx), %r8 // dst_depth_quad | ||||||
|  | movq (%rcx), %rcx // src_depth_quad | ||||||
|  | movq (%r9), %r12 // scale | ||||||
|  | movq 8(%r9), %r15 // bias | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | // ymm0-ymm1: Src | ||||||
|  | // ymm2-ymm3: Weight | ||||||
|  | // ymm4-ymm7: TmpDst | ||||||
|  | // ymm8-ymm15: Dst Sum | ||||||
|  | 
 | ||||||
|  | // Last dst save to ymm8-ymm11 | ||||||
|  | 
 | ||||||
|  | cmpq $0, %r8 | ||||||
|  | je End | ||||||
|  | 
 | ||||||
|  | movq %rsi, %r13 | ||||||
|  | subq $64, %rsp | ||||||
|  | LoopDz: | ||||||
|  |     movq %rcx, %r11 | ||||||
|  |     movq %r13, %rsi | ||||||
|  |     movq %rdx, %r14 | ||||||
|  |     subq $1, %r11 | ||||||
|  |     vpmovsxbw (%rsi), %ymm0 | ||||||
|  |     vpmovsxbw 16(%rsi), %ymm1 | ||||||
|  |     vpmovsxbw (%rdx), %ymm2 | ||||||
|  |     vpmovsxbw 16(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |     vpmaddwd %ymm0, %ymm2, %ymm8 | ||||||
|  |     vpmaddwd %ymm0, %ymm3, %ymm9 | ||||||
|  |     vpmaddwd %ymm1, %ymm2, %ymm12 | ||||||
|  |     vpmaddwd %ymm1, %ymm3, %ymm13 | ||||||
|  |     vpmovsxbw 32(%rdx), %ymm2 | ||||||
|  |     vpmovsxbw 48(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |     vpmaddwd %ymm0, %ymm2, %ymm10 | ||||||
|  |     vpmaddwd %ymm0, %ymm3, %ymm11 | ||||||
|  |     vpmaddwd %ymm1, %ymm2, %ymm14 | ||||||
|  |     vpmaddwd %ymm1, %ymm3, %ymm15 | ||||||
|  |     addq $64, %rdx | ||||||
|  |     addq $64, %rsi | ||||||
|  | 
 | ||||||
|  |     testq %r11, %r11 | ||||||
|  |     je FirstLoopSzEnd | ||||||
|  | 
 | ||||||
|  |     FirstLoopSz: | ||||||
|  |         vpmovsxbw (%rsi), %ymm0 | ||||||
|  |         vpmovsxbw 16(%rsi), %ymm1 | ||||||
|  |         vpmovsxbw (%rdx), %ymm2 | ||||||
|  |         vpmovsxbw 16(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||||
|  |         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||||
|  |         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||||
|  |         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||||
|  |         vpaddd %ymm4, %ymm8, %ymm8 | ||||||
|  |         vpaddd %ymm5, %ymm9, %ymm9 | ||||||
|  |         vpmovsxbw 32(%rdx), %ymm2 | ||||||
|  |         vpmovsxbw 48(%rdx), %ymm3 | ||||||
|  |         vpaddd %ymm6, %ymm12, %ymm12 | ||||||
|  |         vpaddd %ymm7, %ymm13, %ymm13 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||||
|  |         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||||
|  |         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||||
|  |         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||||
|  |         vpaddd %ymm4, %ymm10, %ymm10 | ||||||
|  |         vpaddd %ymm5, %ymm11, %ymm11 | ||||||
|  |         vpaddd %ymm6, %ymm14, %ymm14 | ||||||
|  |         vpaddd %ymm7, %ymm15, %ymm15 | ||||||
|  | 
 | ||||||
|  |         addq $64, %rdx | ||||||
|  |         addq $64, %rsi | ||||||
|  | 
 | ||||||
|  |         subq $1, %r11 | ||||||
|  |         testq %r11, %r11 | ||||||
|  |         jne FirstLoopSz | ||||||
|  | 
 | ||||||
|  |     FirstLoopSzEnd: | ||||||
|  |      | ||||||
|  |     vphaddd %ymm9, %ymm8, %ymm8 | ||||||
|  |     vphaddd %ymm11, %ymm10, %ymm10 | ||||||
|  |     vphaddd %ymm13, %ymm12, %ymm12 | ||||||
|  |     vphaddd %ymm15, %ymm14, %ymm14 | ||||||
|  | 
 | ||||||
|  |     vphaddd %ymm10, %ymm8, %ymm8 | ||||||
|  |     vphaddd %ymm14, %ymm12, %ymm9 | ||||||
|  | 
 | ||||||
|  |     vmovups %ymm8, (%rsp) | ||||||
|  |     vmovups %ymm9, 32(%rsp) | ||||||
|  | 
 | ||||||
|  |     movq %rcx, %r11 | ||||||
|  |     movq %r13, %rsi | ||||||
|  |     movq %r14, %rdx | ||||||
|  |     vpmovsxbw 32(%rsi), %ymm0 | ||||||
|  |     vpmovsxbw 48(%rsi), %ymm1 | ||||||
|  |     vpmovsxbw (%rdx), %ymm2 | ||||||
|  |     vpmovsxbw 16(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |     vpmaddwd %ymm0, %ymm2, %ymm8 | ||||||
|  |     vpmaddwd %ymm0, %ymm3, %ymm9 | ||||||
|  |     vpmaddwd %ymm1, %ymm2, %ymm12 | ||||||
|  |     vpmaddwd %ymm1, %ymm3, %ymm13 | ||||||
|  | 
 | ||||||
|  |     vpmovsxbw 32(%rdx), %ymm2 | ||||||
|  |     vpmovsxbw 48(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |     vpmaddwd %ymm0, %ymm2, %ymm10 | ||||||
|  |     vpmaddwd %ymm0, %ymm3, %ymm11 | ||||||
|  |     vpmaddwd %ymm1, %ymm2, %ymm14 | ||||||
|  |     vpmaddwd %ymm1, %ymm3, %ymm15 | ||||||
|  | 
 | ||||||
|  |     addq $64, %rdx | ||||||
|  |     addq $64, %rsi | ||||||
|  | 
 | ||||||
|  |     subq $1, %r11 | ||||||
|  |     testq %r11, %r11 | ||||||
|  |     je SecondLoopSzEnd | ||||||
|  | 
 | ||||||
|  |     SecondLoopSz: | ||||||
|  |         vpmovsxbw 32(%rsi), %ymm0 | ||||||
|  |         vpmovsxbw 48(%rsi), %ymm1 | ||||||
|  |         vpmovsxbw (%rdx), %ymm2 | ||||||
|  |         vpmovsxbw 16(%rdx), %ymm3 | ||||||
|  | 
 | ||||||
|  |         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||||
|  |         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||||
|  |         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||||
|  |         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||||
|  |         vpaddd %ymm4, %ymm8, %ymm8 | ||||||
|  |         vpaddd %ymm5, %ymm9, %ymm9 | ||||||
|  |         vpmovsxbw 32(%rdx), %ymm2 | ||||||
|  |         vpmovsxbw 48(%rdx), %ymm3 | ||||||
|  |         vpaddd %ymm6, %ymm12, %ymm12 | ||||||
|  |         vpaddd %ymm7, %ymm13, %ymm13 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||||
|  |         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||||
|  |         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||||
|  |         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||||
|  |         vpaddd %ymm4, %ymm10, %ymm10 | ||||||
|  |         vpaddd %ymm5, %ymm11, %ymm11 | ||||||
|  |         vpaddd %ymm6, %ymm14, %ymm14 | ||||||
|  |         vpaddd %ymm7, %ymm15, %ymm15 | ||||||
|  | 
 | ||||||
|  |         addq $64, %rdx | ||||||
|  |         addq $64, %rsi | ||||||
|  | 
 | ||||||
|  |         subq $1, %r11 | ||||||
|  |         testq %r11, %r11 | ||||||
|  |         jne SecondLoopSz | ||||||
|  |     SecondLoopSzEnd: | ||||||
|  | 
 | ||||||
|  |     vphaddd %ymm9, %ymm8, %ymm8 | ||||||
|  |     vphaddd %ymm11, %ymm10, %ymm10 | ||||||
|  |     vphaddd %ymm13, %ymm12, %ymm12 | ||||||
|  |     vphaddd %ymm15, %ymm14, %ymm14 | ||||||
|  | 
 | ||||||
|  |     vphaddd %ymm10, %ymm8, %ymm10 | ||||||
|  |     vphaddd %ymm14, %ymm12, %ymm11 | ||||||
|  | 
 | ||||||
|  |     vmovups (%rsp), %ymm8 | ||||||
|  |     vmovups 32(%rsp), %ymm9 | ||||||
|  | 
 | ||||||
|  |     Last: | ||||||
|  | .macro TRANSPOSE x0, x1, x2, x3 | ||||||
|  |     // 32 = 0 + 16 * 2: frist 128 x0_lo, second 128 x1_lo | ||||||
|  |     // 49 = 1 + 16 * 3: frist 128 x0_hi, second 128 x1_hi | ||||||
|  |     vperm2f128 $32, \x1, \x0, \x2 | ||||||
|  |     vperm2f128 $49, \x1, \x0, \x3 | ||||||
|  | .endm | ||||||
|  |     TRANSPOSE %ymm8, %ymm10, %ymm0, %ymm1 | ||||||
|  |     TRANSPOSE %ymm9, %ymm11, %ymm2, %ymm3 | ||||||
|  | 
 | ||||||
|  |     vpaddd %ymm0, %ymm1, %ymm0 | ||||||
|  |     vpaddd %ymm2, %ymm3, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vbroadcastf128 (%r12), %ymm8 | ||||||
|  |     vbroadcastf128 (%r15), %ymm9 | ||||||
|  | 
 | ||||||
|  |     vpaddd %ymm9, %ymm0, %ymm0 | ||||||
|  |     vpaddd %ymm9, %ymm2, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vcvtdq2ps %ymm0, %ymm0 | ||||||
|  |     vcvtdq2ps %ymm2, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vmulps %ymm8, %ymm0, %ymm0 | ||||||
|  |     vmulps %ymm8, %ymm2, %ymm2 | ||||||
|  |     // zero | ||||||
|  |     vxorps %ymm13, %ymm13, %ymm13 | ||||||
|  | 
 | ||||||
|  |     vbroadcastss 24(%r9), %ymm14 | ||||||
|  |     vbroadcastss 28(%r9), %ymm15 | ||||||
|  |     vbroadcastss 16(%r9), %ymm10 | ||||||
|  |     vbroadcastss 20(%r9), %ymm11 | ||||||
|  | 
 | ||||||
|  |     // Round | ||||||
|  |     vcmpltps %ymm13, %ymm0, %ymm4 | ||||||
|  |     vcmpltps %ymm13, %ymm2, %ymm5 | ||||||
|  | 
 | ||||||
|  |     vblendvps %ymm4, %ymm15, %ymm14, %ymm4 | ||||||
|  |     vblendvps %ymm5, %ymm15, %ymm14, %ymm5 | ||||||
|  | 
 | ||||||
|  |     vaddps %ymm0, %ymm4, %ymm0 | ||||||
|  |     vaddps %ymm2, %ymm5, %ymm2 | ||||||
|  | 
 | ||||||
|  |     // 3: ROUND to Zero | ||||||
|  |     vroundps $3, %ymm0, %ymm0 | ||||||
|  |     vroundps $3, %ymm2, %ymm2 | ||||||
|  |     vcvtps2dq %ymm0, %ymm0 | ||||||
|  |     vcvtps2dq %ymm2, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vpminsd %ymm10, %ymm0, %ymm0 | ||||||
|  |     vpminsd %ymm10, %ymm2, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vpmaxsd %ymm11, %ymm0, %ymm0 | ||||||
|  |     vpmaxsd %ymm11, %ymm2, %ymm2 | ||||||
|  | 
 | ||||||
|  |     vpackssdw %ymm2, %ymm0, %ymm0 | ||||||
|  |     vperm2f128 $1, %ymm0, %ymm0, %ymm1 | ||||||
|  |     vpacksswb %ymm1, %ymm0, %ymm0 | ||||||
|  | 
 | ||||||
|  |     addq $16, %r12 | ||||||
|  |     addq $16, %r15 | ||||||
|  | 
 | ||||||
|  |     vmovups %xmm0, (%rdi) | ||||||
|  |     addq %r10, %rdi | ||||||
|  | 
 | ||||||
|  |     subq $1, %r8 | ||||||
|  |     testq %r8, %r8 | ||||||
|  |     jne LoopDz | ||||||
|  | addq $64, %rsp | ||||||
|  | 
 | ||||||
|  | End: | ||||||
|  | 
 | ||||||
|  | #ifdef WIN32 | ||||||
|  | popq    %r15 | ||||||
|  | popq    %r14 | ||||||
|  | popq    %r13 | ||||||
|  | popq    %r12 | ||||||
|  | popq    %rsi | ||||||
|  | popq    %rdi | ||||||
|  | popq    %rbp | ||||||
|  | #else | ||||||
|  | popq    %r15 | ||||||
|  | popq    %r14 | ||||||
|  | popq    %r13 | ||||||
|  | popq    %r12 | ||||||
|  | popq    %rbp | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | // FIXME: if don't vzeroall, it will cause other op slow | ||||||
|  | vzeroall | ||||||
|  | retq | ||||||
|  | 
 | ||||||
		Loading…
	
		Reference in New Issue