mirror of https://github.com/alibaba/MNN.git
				
				
				
			[PATCH 25/78] [MNN:Speed] Add asm for avx int8
This commit is contained in:
		
							parent
							
								
									70016820a0
								
							
						
					
					
						commit
						1bd8d27131
					
				|  | @ -63,6 +63,8 @@ struct QuanPostTreatParameters { | |||
|     const int32_t* bias; | ||||
|     int32_t maxValue; | ||||
|     int32_t minValue; | ||||
|     float roundValuePos = 0.5f; | ||||
|     float roundValueNeg = -0.5f; | ||||
| }; | ||||
| void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | ||||
| void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post); | ||||
|  |  | |||
|  | @ -203,8 +203,19 @@ void AVX2GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const fl | |||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #ifdef MNN_X86_USE_ASM | ||||
| extern "C" { | ||||
| void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post); | ||||
| } | ||||
| #endif | ||||
| void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post) { | ||||
| #ifdef MNN_X86_USE_ASM | ||||
|     size_t strides[3]; | ||||
|     strides[0] = src_depth_quad; | ||||
|     strides[1] = dst_step; | ||||
|     strides[2] = dst_depth_quad; | ||||
|     _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(dst, src, weight, strides, post); | ||||
| #else | ||||
|     const auto dst_step_tmp = dst_step / sizeof(int8_t); | ||||
|     __m128 zero128 = _mm_set1_ps(0.0f); | ||||
|     __m128 minValue = _mm_set1_ps(post->minValue); | ||||
|  | @ -356,4 +367,5 @@ auto d##i = _mm_add_epi32(d##i##0, d##i##1); | |||
|         d0 = _mm_packs_epi16(d0, d2); | ||||
|         _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0)); | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,309 @@ | |||
| // | ||||
| //  _AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S | ||||
| //  MNN | ||||
| // | ||||
| //  Created by MNN on 2020/11/04. | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited | ||||
| // | ||||
| 
 | ||||
| #include "../MNNAsmGlobal.h" | ||||
| .text | ||||
| .align 4
 | ||||
| 
 | ||||
| //struct QuanPostTreatParameters { | ||||
| //    const float* scale;
 | ||||
| //    const int32_t* bias;
 | ||||
| //    int32_t maxValue;
 | ||||
| //    int32_t minValue;
 | ||||
| //    float roundValuePos = 0.5f;
 | ||||
| //    float roundValueNeg = -0.5f;
 | ||||
| //};
 | ||||
| 
 | ||||
| asm_function _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain | ||||
| //void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post);
 | ||||
| 
 | ||||
| 
 | ||||
| // SystemV Auto: rdi: dst, rsi:src, rdx:weight, rcx:strides, r8: post | ||||
| // Microsoft x64 Auto: rcx:C, rdx:A, r8:B, r9:parameter | ||||
| pushq   %rbp | ||||
| movq    %rsp, %rbp | ||||
| 
 | ||||
| #ifdef WIN32 | ||||
| movq 48(%rsp), %r10 | ||||
| pushq %rdi | ||||
| pushq %rsi | ||||
| pushq %r12 | ||||
| pushq %r13 | ||||
| movq %rcx, %rdi | ||||
| movq %rdx, %rsi | ||||
| movq %r8, %rdx | ||||
| movq %r9, %rcx | ||||
| movq %r10, %r9 | ||||
| pushq   %r14 | ||||
| pushq   %r15 | ||||
| #else | ||||
| pushq   %r12 | ||||
| pushq   %r13 | ||||
| pushq   %r14 | ||||
| pushq   %r15 | ||||
| movq %r8, %r9 | ||||
| #endif | ||||
| 
 | ||||
| movq 8(%rcx), %r10 // dst_step | ||||
| movq 16(%rcx), %r8 // dst_depth_quad | ||||
| movq (%rcx), %rcx // src_depth_quad | ||||
| movq (%r9), %r12 // scale | ||||
| movq 8(%r9), %r15 // bias | ||||
| 
 | ||||
| 
 | ||||
| // ymm0-ymm1: Src | ||||
| // ymm2-ymm3: Weight | ||||
| // ymm4-ymm7: TmpDst | ||||
| // ymm8-ymm15: Dst Sum | ||||
| 
 | ||||
| // Last dst save to ymm8-ymm11 | ||||
| 
 | ||||
| cmpq $0, %r8 | ||||
| je End | ||||
| 
 | ||||
| movq %rsi, %r13 | ||||
| subq $64, %rsp | ||||
| LoopDz: | ||||
|     movq %rcx, %r11 | ||||
|     movq %r13, %rsi | ||||
|     movq %rdx, %r14 | ||||
|     subq $1, %r11 | ||||
|     vpmovsxbw (%rsi), %ymm0 | ||||
|     vpmovsxbw 16(%rsi), %ymm1 | ||||
|     vpmovsxbw (%rdx), %ymm2 | ||||
|     vpmovsxbw 16(%rdx), %ymm3 | ||||
| 
 | ||||
|     vpmaddwd %ymm0, %ymm2, %ymm8 | ||||
|     vpmaddwd %ymm0, %ymm3, %ymm9 | ||||
|     vpmaddwd %ymm1, %ymm2, %ymm12 | ||||
|     vpmaddwd %ymm1, %ymm3, %ymm13 | ||||
|     vpmovsxbw 32(%rdx), %ymm2 | ||||
|     vpmovsxbw 48(%rdx), %ymm3 | ||||
| 
 | ||||
|     vpmaddwd %ymm0, %ymm2, %ymm10 | ||||
|     vpmaddwd %ymm0, %ymm3, %ymm11 | ||||
|     vpmaddwd %ymm1, %ymm2, %ymm14 | ||||
|     vpmaddwd %ymm1, %ymm3, %ymm15 | ||||
|     addq $64, %rdx | ||||
|     addq $64, %rsi | ||||
| 
 | ||||
|     testq %r11, %r11 | ||||
|     je FirstLoopSzEnd | ||||
| 
 | ||||
|     FirstLoopSz: | ||||
|         vpmovsxbw (%rsi), %ymm0 | ||||
|         vpmovsxbw 16(%rsi), %ymm1 | ||||
|         vpmovsxbw (%rdx), %ymm2 | ||||
|         vpmovsxbw 16(%rdx), %ymm3 | ||||
| 
 | ||||
|         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||
|         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||
|         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||
|         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||
|         vpaddd %ymm4, %ymm8, %ymm8 | ||||
|         vpaddd %ymm5, %ymm9, %ymm9 | ||||
|         vpmovsxbw 32(%rdx), %ymm2 | ||||
|         vpmovsxbw 48(%rdx), %ymm3 | ||||
|         vpaddd %ymm6, %ymm12, %ymm12 | ||||
|         vpaddd %ymm7, %ymm13, %ymm13 | ||||
| 
 | ||||
| 
 | ||||
|         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||
|         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||
|         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||
|         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||
|         vpaddd %ymm4, %ymm10, %ymm10 | ||||
|         vpaddd %ymm5, %ymm11, %ymm11 | ||||
|         vpaddd %ymm6, %ymm14, %ymm14 | ||||
|         vpaddd %ymm7, %ymm15, %ymm15 | ||||
| 
 | ||||
|         addq $64, %rdx | ||||
|         addq $64, %rsi | ||||
| 
 | ||||
|         subq $1, %r11 | ||||
|         testq %r11, %r11 | ||||
|         jne FirstLoopSz | ||||
| 
 | ||||
|     FirstLoopSzEnd: | ||||
|      | ||||
|     vphaddd %ymm9, %ymm8, %ymm8 | ||||
|     vphaddd %ymm11, %ymm10, %ymm10 | ||||
|     vphaddd %ymm13, %ymm12, %ymm12 | ||||
|     vphaddd %ymm15, %ymm14, %ymm14 | ||||
| 
 | ||||
|     vphaddd %ymm10, %ymm8, %ymm8 | ||||
|     vphaddd %ymm14, %ymm12, %ymm9 | ||||
| 
 | ||||
|     vmovups %ymm8, (%rsp) | ||||
|     vmovups %ymm9, 32(%rsp) | ||||
| 
 | ||||
|     movq %rcx, %r11 | ||||
|     movq %r13, %rsi | ||||
|     movq %r14, %rdx | ||||
|     vpmovsxbw 32(%rsi), %ymm0 | ||||
|     vpmovsxbw 48(%rsi), %ymm1 | ||||
|     vpmovsxbw (%rdx), %ymm2 | ||||
|     vpmovsxbw 16(%rdx), %ymm3 | ||||
| 
 | ||||
|     vpmaddwd %ymm0, %ymm2, %ymm8 | ||||
|     vpmaddwd %ymm0, %ymm3, %ymm9 | ||||
|     vpmaddwd %ymm1, %ymm2, %ymm12 | ||||
|     vpmaddwd %ymm1, %ymm3, %ymm13 | ||||
| 
 | ||||
|     vpmovsxbw 32(%rdx), %ymm2 | ||||
|     vpmovsxbw 48(%rdx), %ymm3 | ||||
| 
 | ||||
|     vpmaddwd %ymm0, %ymm2, %ymm10 | ||||
|     vpmaddwd %ymm0, %ymm3, %ymm11 | ||||
|     vpmaddwd %ymm1, %ymm2, %ymm14 | ||||
|     vpmaddwd %ymm1, %ymm3, %ymm15 | ||||
| 
 | ||||
|     addq $64, %rdx | ||||
|     addq $64, %rsi | ||||
| 
 | ||||
|     subq $1, %r11 | ||||
|     testq %r11, %r11 | ||||
|     je SecondLoopSzEnd | ||||
| 
 | ||||
|     SecondLoopSz: | ||||
|         vpmovsxbw 32(%rsi), %ymm0 | ||||
|         vpmovsxbw 48(%rsi), %ymm1 | ||||
|         vpmovsxbw (%rdx), %ymm2 | ||||
|         vpmovsxbw 16(%rdx), %ymm3 | ||||
| 
 | ||||
|         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||
|         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||
|         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||
|         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||
|         vpaddd %ymm4, %ymm8, %ymm8 | ||||
|         vpaddd %ymm5, %ymm9, %ymm9 | ||||
|         vpmovsxbw 32(%rdx), %ymm2 | ||||
|         vpmovsxbw 48(%rdx), %ymm3 | ||||
|         vpaddd %ymm6, %ymm12, %ymm12 | ||||
|         vpaddd %ymm7, %ymm13, %ymm13 | ||||
| 
 | ||||
| 
 | ||||
|         vpmaddwd %ymm0, %ymm2, %ymm4 | ||||
|         vpmaddwd %ymm0, %ymm3, %ymm5 | ||||
|         vpmaddwd %ymm1, %ymm2, %ymm6 | ||||
|         vpmaddwd %ymm1, %ymm3, %ymm7 | ||||
|         vpaddd %ymm4, %ymm10, %ymm10 | ||||
|         vpaddd %ymm5, %ymm11, %ymm11 | ||||
|         vpaddd %ymm6, %ymm14, %ymm14 | ||||
|         vpaddd %ymm7, %ymm15, %ymm15 | ||||
| 
 | ||||
|         addq $64, %rdx | ||||
|         addq $64, %rsi | ||||
| 
 | ||||
|         subq $1, %r11 | ||||
|         testq %r11, %r11 | ||||
|         jne SecondLoopSz | ||||
|     SecondLoopSzEnd: | ||||
| 
 | ||||
|     vphaddd %ymm9, %ymm8, %ymm8 | ||||
|     vphaddd %ymm11, %ymm10, %ymm10 | ||||
|     vphaddd %ymm13, %ymm12, %ymm12 | ||||
|     vphaddd %ymm15, %ymm14, %ymm14 | ||||
| 
 | ||||
|     vphaddd %ymm10, %ymm8, %ymm10 | ||||
|     vphaddd %ymm14, %ymm12, %ymm11 | ||||
| 
 | ||||
|     vmovups (%rsp), %ymm8 | ||||
|     vmovups 32(%rsp), %ymm9 | ||||
| 
 | ||||
|     Last: | ||||
| .macro TRANSPOSE x0, x1, x2, x3 | ||||
|     // 32 = 0 + 16 * 2: frist 128 x0_lo, second 128 x1_lo | ||||
|     // 49 = 1 + 16 * 3: frist 128 x0_hi, second 128 x1_hi | ||||
|     vperm2f128 $32, \x1, \x0, \x2 | ||||
|     vperm2f128 $49, \x1, \x0, \x3 | ||||
| .endm | ||||
|     TRANSPOSE %ymm8, %ymm10, %ymm0, %ymm1 | ||||
|     TRANSPOSE %ymm9, %ymm11, %ymm2, %ymm3 | ||||
| 
 | ||||
|     vpaddd %ymm0, %ymm1, %ymm0 | ||||
|     vpaddd %ymm2, %ymm3, %ymm2 | ||||
| 
 | ||||
|     vbroadcastf128 (%r12), %ymm8 | ||||
|     vbroadcastf128 (%r15), %ymm9 | ||||
| 
 | ||||
|     vpaddd %ymm9, %ymm0, %ymm0 | ||||
|     vpaddd %ymm9, %ymm2, %ymm2 | ||||
| 
 | ||||
|     vcvtdq2ps %ymm0, %ymm0 | ||||
|     vcvtdq2ps %ymm2, %ymm2 | ||||
| 
 | ||||
|     vmulps %ymm8, %ymm0, %ymm0 | ||||
|     vmulps %ymm8, %ymm2, %ymm2 | ||||
|     // zero | ||||
|     vxorps %ymm13, %ymm13, %ymm13 | ||||
| 
 | ||||
|     vbroadcastss 24(%r9), %ymm14 | ||||
|     vbroadcastss 28(%r9), %ymm15 | ||||
|     vbroadcastss 16(%r9), %ymm10 | ||||
|     vbroadcastss 20(%r9), %ymm11 | ||||
| 
 | ||||
|     // Round | ||||
|     vcmpltps %ymm13, %ymm0, %ymm4 | ||||
|     vcmpltps %ymm13, %ymm2, %ymm5 | ||||
| 
 | ||||
|     vblendvps %ymm4, %ymm15, %ymm14, %ymm4 | ||||
|     vblendvps %ymm5, %ymm15, %ymm14, %ymm5 | ||||
| 
 | ||||
|     vaddps %ymm0, %ymm4, %ymm0 | ||||
|     vaddps %ymm2, %ymm5, %ymm2 | ||||
| 
 | ||||
|     // 3: ROUND to Zero | ||||
|     vroundps $3, %ymm0, %ymm0 | ||||
|     vroundps $3, %ymm2, %ymm2 | ||||
|     vcvtps2dq %ymm0, %ymm0 | ||||
|     vcvtps2dq %ymm2, %ymm2 | ||||
| 
 | ||||
|     vpminsd %ymm10, %ymm0, %ymm0 | ||||
|     vpminsd %ymm10, %ymm2, %ymm2 | ||||
| 
 | ||||
|     vpmaxsd %ymm11, %ymm0, %ymm0 | ||||
|     vpmaxsd %ymm11, %ymm2, %ymm2 | ||||
| 
 | ||||
|     vpackssdw %ymm2, %ymm0, %ymm0 | ||||
|     vperm2f128 $1, %ymm0, %ymm0, %ymm1 | ||||
|     vpacksswb %ymm1, %ymm0, %ymm0 | ||||
| 
 | ||||
|     addq $16, %r12 | ||||
|     addq $16, %r15 | ||||
| 
 | ||||
|     vmovups %xmm0, (%rdi) | ||||
|     addq %r10, %rdi | ||||
| 
 | ||||
|     subq $1, %r8 | ||||
|     testq %r8, %r8 | ||||
|     jne LoopDz | ||||
| addq $64, %rsp | ||||
| 
 | ||||
| End: | ||||
| 
 | ||||
| #ifdef WIN32 | ||||
| popq    %r15 | ||||
| popq    %r14 | ||||
| popq    %r13 | ||||
| popq    %r12 | ||||
| popq    %rsi | ||||
| popq    %rdi | ||||
| popq    %rbp | ||||
| #else | ||||
| popq    %r15 | ||||
| popq    %r14 | ||||
| popq    %r13 | ||||
| popq    %r12 | ||||
| popq    %rbp | ||||
| #endif | ||||
| 
 | ||||
| // FIXME: if don't vzeroall, it will cause other op slow | ||||
| vzeroall | ||||
| retq | ||||
| 
 | ||||
		Loading…
	
		Reference in New Issue