[PATCH 25/78] [MNN:Speed] Add asm for avx int8

This commit is contained in:
xiaying 2020-11-04 19:03:02 +08:00 committed by xiaying
parent 70016820a0
commit 1bd8d27131
3 changed files with 324 additions and 1 deletions

View File

@ -63,6 +63,8 @@ struct QuanPostTreatParameters {
const int32_t* bias;
int32_t maxValue;
int32_t minValue;
float roundValuePos = 0.5f;
float roundValueNeg = -0.5f;
};
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post);
void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post);

View File

@ -203,8 +203,19 @@ void AVX2GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const fl
}
}
}
#ifdef MNN_X86_USE_ASM
extern "C" {
void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post);
}
#endif
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post) {
#ifdef MNN_X86_USE_ASM
size_t strides[3];
strides[0] = src_depth_quad;
strides[1] = dst_step;
strides[2] = dst_depth_quad;
_AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(dst, src, weight, strides, post);
#else
const auto dst_step_tmp = dst_step / sizeof(int8_t);
__m128 zero128 = _mm_set1_ps(0.0f);
__m128 minValue = _mm_set1_ps(post->minValue);
@ -356,4 +367,5 @@ auto d##i = _mm_add_epi32(d##i##0, d##i##1);
d0 = _mm_packs_epi16(d0, d2);
_mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0));
}
#endif
}

View File

@ -0,0 +1,309 @@
//
// _AVX_MNNGemmInt8AddBiasScale_16x4_Unit.S
// MNN
//
// Created by MNN on 2020/11/04.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "../MNNAsmGlobal.h"
.text
.align 4
//struct QuanPostTreatParameters {
// const float* scale;
// const int32_t* bias;
// int32_t maxValue;
// int32_t minValue;
// float roundValuePos = 0.5f;
// float roundValueNeg = -0.5f;
//};
asm_function _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain
//void _AVX_MNNGemmInt8AddBiasScale_16x4_UnitMain(int8_t* dst, const int8_t* src, const int8_t* weight, const size_t* strides, const QuanPostTreatParameters* post);
// SystemV Auto: rdi: dst, rsi:src, rdx:weight, rcx:strides, r8: post
// Microsoft x64 Auto: rcx:C, rdx:A, r8:B, r9:parameter
pushq %rbp
movq %rsp, %rbp
#ifdef WIN32
movq 48(%rsp), %r10
pushq %rdi
pushq %rsi
pushq %r12
pushq %r13
movq %rcx, %rdi
movq %rdx, %rsi
movq %r8, %rdx
movq %r9, %rcx
movq %r10, %r9
pushq %r14
pushq %r15
#else
pushq %r12
pushq %r13
pushq %r14
pushq %r15
movq %r8, %r9
#endif
movq 8(%rcx), %r10 // dst_step
movq 16(%rcx), %r8 // dst_depth_quad
movq (%rcx), %rcx // src_depth_quad
movq (%r9), %r12 // scale
movq 8(%r9), %r15 // bias
// ymm0-ymm1: Src
// ymm2-ymm3: Weight
// ymm4-ymm7: TmpDst
// ymm8-ymm15: Dst Sum
// Last dst save to ymm8-ymm11
cmpq $0, %r8
je End
movq %rsi, %r13
subq $64, %rsp
LoopDz:
movq %rcx, %r11
movq %r13, %rsi
movq %rdx, %r14
subq $1, %r11
vpmovsxbw (%rsi), %ymm0
vpmovsxbw 16(%rsi), %ymm1
vpmovsxbw (%rdx), %ymm2
vpmovsxbw 16(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm8
vpmaddwd %ymm0, %ymm3, %ymm9
vpmaddwd %ymm1, %ymm2, %ymm12
vpmaddwd %ymm1, %ymm3, %ymm13
vpmovsxbw 32(%rdx), %ymm2
vpmovsxbw 48(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm10
vpmaddwd %ymm0, %ymm3, %ymm11
vpmaddwd %ymm1, %ymm2, %ymm14
vpmaddwd %ymm1, %ymm3, %ymm15
addq $64, %rdx
addq $64, %rsi
testq %r11, %r11
je FirstLoopSzEnd
FirstLoopSz:
vpmovsxbw (%rsi), %ymm0
vpmovsxbw 16(%rsi), %ymm1
vpmovsxbw (%rdx), %ymm2
vpmovsxbw 16(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm4
vpmaddwd %ymm0, %ymm3, %ymm5
vpmaddwd %ymm1, %ymm2, %ymm6
vpmaddwd %ymm1, %ymm3, %ymm7
vpaddd %ymm4, %ymm8, %ymm8
vpaddd %ymm5, %ymm9, %ymm9
vpmovsxbw 32(%rdx), %ymm2
vpmovsxbw 48(%rdx), %ymm3
vpaddd %ymm6, %ymm12, %ymm12
vpaddd %ymm7, %ymm13, %ymm13
vpmaddwd %ymm0, %ymm2, %ymm4
vpmaddwd %ymm0, %ymm3, %ymm5
vpmaddwd %ymm1, %ymm2, %ymm6
vpmaddwd %ymm1, %ymm3, %ymm7
vpaddd %ymm4, %ymm10, %ymm10
vpaddd %ymm5, %ymm11, %ymm11
vpaddd %ymm6, %ymm14, %ymm14
vpaddd %ymm7, %ymm15, %ymm15
addq $64, %rdx
addq $64, %rsi
subq $1, %r11
testq %r11, %r11
jne FirstLoopSz
FirstLoopSzEnd:
vphaddd %ymm9, %ymm8, %ymm8
vphaddd %ymm11, %ymm10, %ymm10
vphaddd %ymm13, %ymm12, %ymm12
vphaddd %ymm15, %ymm14, %ymm14
vphaddd %ymm10, %ymm8, %ymm8
vphaddd %ymm14, %ymm12, %ymm9
vmovups %ymm8, (%rsp)
vmovups %ymm9, 32(%rsp)
movq %rcx, %r11
movq %r13, %rsi
movq %r14, %rdx
vpmovsxbw 32(%rsi), %ymm0
vpmovsxbw 48(%rsi), %ymm1
vpmovsxbw (%rdx), %ymm2
vpmovsxbw 16(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm8
vpmaddwd %ymm0, %ymm3, %ymm9
vpmaddwd %ymm1, %ymm2, %ymm12
vpmaddwd %ymm1, %ymm3, %ymm13
vpmovsxbw 32(%rdx), %ymm2
vpmovsxbw 48(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm10
vpmaddwd %ymm0, %ymm3, %ymm11
vpmaddwd %ymm1, %ymm2, %ymm14
vpmaddwd %ymm1, %ymm3, %ymm15
addq $64, %rdx
addq $64, %rsi
subq $1, %r11
testq %r11, %r11
je SecondLoopSzEnd
SecondLoopSz:
vpmovsxbw 32(%rsi), %ymm0
vpmovsxbw 48(%rsi), %ymm1
vpmovsxbw (%rdx), %ymm2
vpmovsxbw 16(%rdx), %ymm3
vpmaddwd %ymm0, %ymm2, %ymm4
vpmaddwd %ymm0, %ymm3, %ymm5
vpmaddwd %ymm1, %ymm2, %ymm6
vpmaddwd %ymm1, %ymm3, %ymm7
vpaddd %ymm4, %ymm8, %ymm8
vpaddd %ymm5, %ymm9, %ymm9
vpmovsxbw 32(%rdx), %ymm2
vpmovsxbw 48(%rdx), %ymm3
vpaddd %ymm6, %ymm12, %ymm12
vpaddd %ymm7, %ymm13, %ymm13
vpmaddwd %ymm0, %ymm2, %ymm4
vpmaddwd %ymm0, %ymm3, %ymm5
vpmaddwd %ymm1, %ymm2, %ymm6
vpmaddwd %ymm1, %ymm3, %ymm7
vpaddd %ymm4, %ymm10, %ymm10
vpaddd %ymm5, %ymm11, %ymm11
vpaddd %ymm6, %ymm14, %ymm14
vpaddd %ymm7, %ymm15, %ymm15
addq $64, %rdx
addq $64, %rsi
subq $1, %r11
testq %r11, %r11
jne SecondLoopSz
SecondLoopSzEnd:
vphaddd %ymm9, %ymm8, %ymm8
vphaddd %ymm11, %ymm10, %ymm10
vphaddd %ymm13, %ymm12, %ymm12
vphaddd %ymm15, %ymm14, %ymm14
vphaddd %ymm10, %ymm8, %ymm10
vphaddd %ymm14, %ymm12, %ymm11
vmovups (%rsp), %ymm8
vmovups 32(%rsp), %ymm9
Last:
.macro TRANSPOSE x0, x1, x2, x3
// 32 = 0 + 16 * 2: frist 128 x0_lo, second 128 x1_lo
// 49 = 1 + 16 * 3: frist 128 x0_hi, second 128 x1_hi
vperm2f128 $32, \x1, \x0, \x2
vperm2f128 $49, \x1, \x0, \x3
.endm
TRANSPOSE %ymm8, %ymm10, %ymm0, %ymm1
TRANSPOSE %ymm9, %ymm11, %ymm2, %ymm3
vpaddd %ymm0, %ymm1, %ymm0
vpaddd %ymm2, %ymm3, %ymm2
vbroadcastf128 (%r12), %ymm8
vbroadcastf128 (%r15), %ymm9
vpaddd %ymm9, %ymm0, %ymm0
vpaddd %ymm9, %ymm2, %ymm2
vcvtdq2ps %ymm0, %ymm0
vcvtdq2ps %ymm2, %ymm2
vmulps %ymm8, %ymm0, %ymm0
vmulps %ymm8, %ymm2, %ymm2
// zero
vxorps %ymm13, %ymm13, %ymm13
vbroadcastss 24(%r9), %ymm14
vbroadcastss 28(%r9), %ymm15
vbroadcastss 16(%r9), %ymm10
vbroadcastss 20(%r9), %ymm11
// Round
vcmpltps %ymm13, %ymm0, %ymm4
vcmpltps %ymm13, %ymm2, %ymm5
vblendvps %ymm4, %ymm15, %ymm14, %ymm4
vblendvps %ymm5, %ymm15, %ymm14, %ymm5
vaddps %ymm0, %ymm4, %ymm0
vaddps %ymm2, %ymm5, %ymm2
// 3: ROUND to Zero
vroundps $3, %ymm0, %ymm0
vroundps $3, %ymm2, %ymm2
vcvtps2dq %ymm0, %ymm0
vcvtps2dq %ymm2, %ymm2
vpminsd %ymm10, %ymm0, %ymm0
vpminsd %ymm10, %ymm2, %ymm2
vpmaxsd %ymm11, %ymm0, %ymm0
vpmaxsd %ymm11, %ymm2, %ymm2
vpackssdw %ymm2, %ymm0, %ymm0
vperm2f128 $1, %ymm0, %ymm0, %ymm1
vpacksswb %ymm1, %ymm0, %ymm0
addq $16, %r12
addq $16, %r15
vmovups %xmm0, (%rdi)
addq %r10, %rdi
subq $1, %r8
testq %r8, %r8
jne LoopDz
addq $64, %rsp
End:
#ifdef WIN32
popq %r15
popq %r14
popq %r13
popq %r12
popq %rsi
popq %rdi
popq %rbp
#else
popq %r15
popq %r14
popq %r13
popq %r12
popq %rbp
#endif
// FIXME: if don't vzeroall, it will cause other op slow
vzeroall
retq