[PATCH 22/28] [MNN:Speed] 8x8 Gemm and cache prefetch optimize

This commit is contained in:
hebin 2020-03-12 11:34:45 +08:00 committed by xiaying
parent cf2ddd36a0
commit c0cb82d9ab
15 changed files with 381 additions and 242 deletions

View File

@ -1,5 +1,5 @@
//
// MNNGemmFloatUnit_4.S
// MNNGemmFloatUnit.S
// MNN
//
// Created by MNN on 2019/02/04.
@ -13,8 +13,8 @@
.text
.align 5
asm_function MNNGemmFloatUnit_4
//void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
asm_function MNNGemmFloatUnit
//void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
//Auto:
//r0:dstOrigin, r1:src, r2: weight, r3:src_depth_quad

View File

@ -0,0 +1,282 @@
//
// MNNGemmFloatUnit.S
// MNN
//
// Created by MNN on 2019/02/04.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
// 8x8 GEMM micro-kernel: multiplies an 8-column source tile (NC4HW4
// layout, 4 floats per column per z step) against the weights, producing
// TWO output-channel quads (8 channels) per main-loop iteration, with a
// single-quad tail path for an odd dst_depth_quad.
asm_function MNNGemmFloatUnit
//void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
//Auto
//x0: dst, x1:src, x2:weight, x3:src_depth_quad
//x4:dst_step, x5:dst_depth_quad, x6: weight_depth_offset
mov x12, #4 //sizeof(float)
// Convert the element strides passed by the caller into byte strides.
mul x4, x12, x4
mul x6, x12, x6
// x11 = byte distance between consecutive dz weight planes:
// weight_depth_offset*4 + src_depth_quad*64 (each z step consumes
// 64 bytes = 4x4 floats of weight per output quad).
add x11, x6, x3, LSL #6
// Spill callee-saved SIMD regs v8-v15 (AAPCS64) — they are used as
// weight registers below. sp ends up back at (entry sp - 128).
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
// Fewer than two output quads: go straight to the single-quad tail.
cmp x5, #2
blt LoopDzExtra
LoopDz:
// Two output quads per iteration.
// v0-v7 : 8 source columns, v8-v11 : weight quad 0 (z lanes 0-3),
// v12-v15: weight quad 1, v16-v23 / v24-v31: accumulators for quad 0 / 1.
mov x8, x1
subs x9, x3, #1
// First z step is peeled so the accumulators can be seeded with fmul
// (no zeroing pass needed). Only lane [0] is consumed here; lanes 1-3
// of this z step are finished at the top of L8LoopZ / in L8LoopZEnd.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2]
add x2, x2, x11
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmul v16.4s, v8.4s, v0.s[0]
fmul v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmul v18.4s, v8.4s, v2.s[0]
fmul v19.4s, v8.4s, v3.s[0]
// Load quad 1's weights, then rebase x2 to quad 0's plane; the net
// effect of the post-increment + sub is x2 advancing 64B per z step.
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
sub x2, x2, x11
fmul v20.4s, v8.4s, v4.s[0]
fmul v21.4s, v8.4s, v5.s[0]
fmul v22.4s, v8.4s, v6.s[0]
fmul v23.4s, v8.4s, v7.s[0]
fmul v24.4s, v12.4s, v0.s[0]
fmul v25.4s, v12.4s, v1.s[0]
fmul v26.4s, v12.4s, v2.s[0]
fmul v27.4s, v12.4s, v3.s[0]
fmul v28.4s, v12.4s, v4.s[0]
fmul v29.4s, v12.4s, v5.s[0]
fmul v30.4s, v12.4s, v6.s[0]
fmul v31.4s, v12.4s, v7.s[0]
// src_depth_quad == 1: the peeled step was the only one.
beq L8LoopZEnd
L8LoopZ:
// Prefetch the next weight lines of both quads (x2+128 within quad 0's
// plane, and the same offset in quad 1's plane via +x11) plus the next
// 128 bytes of source, so the fmla chain below hides the loads.
add x2, x2, #128
prfm pldl1keep, [x2]
prfm pldl1keep, [x2, x11]
sub x2, x2, #128
prfm pldl1keep, [x8, #128]
prfm pldl1keep, [x8, #192]
// Finish lanes 1-3 of the previous z step (software pipelining).
fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]
fmla v24.4s, v13.4s, v0.s[1]
fmla v25.4s, v13.4s, v1.s[1]
fmla v26.4s, v13.4s, v2.s[1]
fmla v27.4s, v13.4s, v3.s[1]
fmla v28.4s, v13.4s, v4.s[1]
fmla v29.4s, v13.4s, v5.s[1]
fmla v30.4s, v13.4s, v6.s[1]
fmla v31.4s, v13.4s, v7.s[1]
fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]
fmla v24.4s, v14.4s, v0.s[2]
fmla v25.4s, v14.4s, v1.s[2]
fmla v26.4s, v14.4s, v2.s[2]
fmla v27.4s, v14.4s, v3.s[2]
fmla v28.4s, v14.4s, v4.s[2]
fmla v29.4s, v14.4s, v5.s[2]
fmla v30.4s, v14.4s, v6.s[2]
fmla v31.4s, v14.4s, v7.s[2]
fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]
fmla v24.4s, v15.4s, v0.s[3]
fmla v25.4s, v15.4s, v1.s[3]
fmla v26.4s, v15.4s, v2.s[3]
fmla v27.4s, v15.4s, v3.s[3]
fmla v28.4s, v15.4s, v4.s[3]
fmla v29.4s, v15.4s, v5.s[3]
fmla v30.4s, v15.4s, v6.s[3]
fmla v31.4s, v15.4s, v7.s[3]
// Start the next z step: reload weights of both quads and the next
// 8 source columns, consume lane [0] (same pattern as the peel).
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2]
add x2, x2, x11
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v18.4s, v8.4s, v2.s[0]
fmla v19.4s, v8.4s, v3.s[0]
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
sub x2, x2, x11
fmla v20.4s, v8.4s, v4.s[0]
fmla v21.4s, v8.4s, v5.s[0]
fmla v22.4s, v8.4s, v6.s[0]
fmla v23.4s, v8.4s, v7.s[0]
fmla v24.4s, v12.4s, v0.s[0]
fmla v25.4s, v12.4s, v1.s[0]
fmla v26.4s, v12.4s, v2.s[0]
fmla v27.4s, v12.4s, v3.s[0]
fmla v28.4s, v12.4s, v4.s[0]
fmla v29.4s, v12.4s, v5.s[0]
fmla v30.4s, v12.4s, v6.s[0]
fmla v31.4s, v12.4s, v7.s[0]
subs x9, x9, #1
bne L8LoopZ
L8LoopZEnd:
// Drain: lanes 1-3 of the final z step, with the stores of finished
// accumulators interleaved into the tail of the fmla chain.
fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]
fmla v24.4s, v13.4s, v0.s[1]
fmla v25.4s, v13.4s, v1.s[1]
fmla v26.4s, v13.4s, v2.s[1]
fmla v27.4s, v13.4s, v3.s[1]
fmla v28.4s, v13.4s, v4.s[1]
fmla v29.4s, v13.4s, v5.s[1]
fmla v30.4s, v13.4s, v6.s[1]
fmla v31.4s, v13.4s, v7.s[1]
fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]
fmla v24.4s, v14.4s, v0.s[2]
fmla v25.4s, v14.4s, v1.s[2]
fmla v26.4s, v14.4s, v2.s[2]
fmla v27.4s, v14.4s, v3.s[2]
fmla v28.4s, v14.4s, v4.s[2]
fmla v29.4s, v14.4s, v5.s[2]
fmla v30.4s, v14.4s, v6.s[2]
fmla v31.4s, v14.4s, v7.s[2]
mov x12, x0 // remember this row's dst base for re-basing below
fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]
fmla v24.4s, v15.4s, v0.s[3]
fmla v25.4s, v15.4s, v1.s[3]
fmla v26.4s, v15.4s, v2.s[3]
fmla v27.4s, v15.4s, v3.s[3]
fmla v28.4s, v15.4s, v4.s[3]
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 // quad 0, cols 0-3
fmla v29.4s, v15.4s, v5.s[3]
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 // quad 0, cols 4-7
fmla v30.4s, v15.4s, v6.s[3]
add x0, x12, x4 // quad 1's dst row = base + dst_step
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 // quad 1, cols 0-3
// x2 currently sits 64*src_depth_quad bytes into quad 0's plane;
// adding x11 + x6 lands exactly at the next quad pair's weights.
add x2, x2, x11
fmla v31.4s, v15.4s, v7.s[3]
add x2, x2, x6
sub x5, x5, #2
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 // quad 1, cols 4-7
add x0, x12, x4, LSL #1 // advance dst by two quad rows
// x5 == 0 -> done; x5 >= 2 -> another double-quad pass;
// x5 == 1 -> fall through to the single-quad tail.
cmp x5, #1
blt LoopDzEnd
bgt LoopDz
LoopDzExtra:
// Tail: one remaining output quad (4 channels) x 8 columns.
// Accumulators v16-v23 are zero-initialized here (no fmul peel);
// w11 reuse is safe — the x11 stride is not needed on this path.
mov w11, #0
mov x8, x1
mov x9, x3
dup v16.4s, w11
dup v17.4s, w11
dup v18.4s, w11
dup v19.4s, w11
dup v20.4s, w11
dup v21.4s, w11
dup v22.4s, w11
dup v23.4s, w11
L4LoopZ:
// Weights are consumed sequentially; no pointer fix-up is needed
// afterwards since this is the last quad.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v18.4s, v8.4s, v2.s[0]
fmla v19.4s, v8.4s, v3.s[0]
fmla v20.4s, v8.4s, v4.s[0]
fmla v21.4s, v8.4s, v5.s[0]
fmla v22.4s, v8.4s, v6.s[0]
fmla v23.4s, v8.4s, v7.s[0]
fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]
fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]
fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]
subs x9, x9, #1
bne L4LoopZ
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
LoopDzEnd:
// Restore callee-saved v8-v15 and return.
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif

View File

@ -1,192 +0,0 @@
//
// MNNGemmFloatUnit_4.S
// MNN
//
// Created by MNN on 2019/02/04.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
// GEMM micro-kernel (removed by this patch, replaced by the 8x8 version):
// 14 source columns (v0-v13) x ONE output-channel quad (weights v14-v17)
// per LoopDz iteration; accumulators v18-v31.
asm_function MNNGemmFloatUnit_4
//void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
//Auto
//x0: dst, x1:src, x2:weight, x3:src_depth_quad
//x4:dst_step, x5:dst_depth_quad, x6: weight_depth_offset
mov x12, #4//sizeof(float)
// Convert element strides to byte strides.
mul x4, x12, x4
mul x6, x12, x6
// Spill callee-saved v8-v15 (AAPCS64): they hold source columns 8-13.
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
LoopDz:
// One output quad per iteration; the first z step is peeled so the
// accumulators are seeded with fmul (lane [0] only; lanes 1-3 are
// finished at the top of L14LoopZ / in L14LoopZEnd).
mov x8, x1
subs x9, x3, #1
ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x2], #64 // weight z lanes 0-3
ld1 {v0.4s, v1.4s}, [x8], #32
fmul v18.4s, v14.4s, v0.s[0]
ld1 {v2.4s, v3.4s}, [x8], #32
fmul v19.4s, v14.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmul v20.4s, v14.4s, v2.s[0]
fmul v21.4s, v14.4s, v3.s[0]
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x8], #64
fmul v22.4s, v14.4s, v4.s[0]
fmul v23.4s, v14.4s, v5.s[0]
ld1 {v12.4s, v13.4s}, [x8], #32
fmul v24.4s, v14.4s, v6.s[0]
fmul v25.4s, v14.4s, v7.s[0]
fmul v26.4s, v14.4s, v8.s[0]
fmul v27.4s, v14.4s, v9.s[0]
fmul v28.4s, v14.4s, v10.s[0]
fmul v29.4s, v14.4s, v11.s[0]
fmul v30.4s, v14.4s, v12.s[0]
fmul v31.4s, v14.4s, v13.s[0]
// src_depth_quad == 1: the peeled step was the only one.
beq L14LoopZEnd
L14LoopZ:
// Finish lanes 1-3 of the previous z step (software pipelining).
fmla v18.4s, v15.4s, v0.s[1]
fmla v19.4s, v15.4s, v1.s[1]
fmla v20.4s, v15.4s, v2.s[1]
fmla v21.4s, v15.4s, v3.s[1]
fmla v22.4s, v15.4s, v4.s[1]
fmla v23.4s, v15.4s, v5.s[1]
fmla v24.4s, v15.4s, v6.s[1]
fmla v25.4s, v15.4s, v7.s[1]
fmla v26.4s, v15.4s, v8.s[1]
fmla v27.4s, v15.4s, v9.s[1]
fmla v28.4s, v15.4s, v10.s[1]
fmla v29.4s, v15.4s, v11.s[1]
fmla v30.4s, v15.4s, v12.s[1]
fmla v31.4s, v15.4s, v13.s[1]
fmla v18.4s, v16.4s, v0.s[2]
fmla v19.4s, v16.4s, v1.s[2]
fmla v20.4s, v16.4s, v2.s[2]
fmla v21.4s, v16.4s, v3.s[2]
fmla v22.4s, v16.4s, v4.s[2]
fmla v23.4s, v16.4s, v5.s[2]
fmla v24.4s, v16.4s, v6.s[2]
fmla v25.4s, v16.4s, v7.s[2]
fmla v26.4s, v16.4s, v8.s[2]
fmla v27.4s, v16.4s, v9.s[2]
fmla v28.4s, v16.4s, v10.s[2]
fmla v29.4s, v16.4s, v11.s[2]
fmla v30.4s, v16.4s, v12.s[2]
fmla v31.4s, v16.4s, v13.s[2]
fmla v18.4s, v17.4s, v0.s[3]
fmla v19.4s, v17.4s, v1.s[3]
fmla v20.4s, v17.4s, v2.s[3]
fmla v21.4s, v17.4s, v3.s[3]
fmla v22.4s, v17.4s, v4.s[3]
fmla v23.4s, v17.4s, v5.s[3]
fmla v24.4s, v17.4s, v6.s[3]
fmla v25.4s, v17.4s, v7.s[3]
fmla v26.4s, v17.4s, v8.s[3]
fmla v27.4s, v17.4s, v9.s[3]
fmla v28.4s, v17.4s, v10.s[3]
fmla v29.4s, v17.4s, v11.s[3]
fmla v30.4s, v17.4s, v12.s[3]
fmla v31.4s, v17.4s, v13.s[3]
// Next z step: reload weights and the 14 source columns, consume
// lane [0] (same pattern as the peel).
ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x2], #64
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v18.4s, v14.4s, v0.s[0]
fmla v19.4s, v14.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v20.4s, v14.4s, v2.s[0]
fmla v21.4s, v14.4s, v3.s[0]
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x8], #64
fmla v22.4s, v14.4s, v4.s[0]
fmla v23.4s, v14.4s, v5.s[0]
ld1 {v12.4s, v13.4s}, [x8], #32
fmla v24.4s, v14.4s, v6.s[0]
fmla v25.4s, v14.4s, v7.s[0]
fmla v26.4s, v14.4s, v8.s[0]
fmla v27.4s, v14.4s, v9.s[0]
fmla v28.4s, v14.4s, v10.s[0]
fmla v29.4s, v14.4s, v11.s[0]
fmla v30.4s, v14.4s, v12.s[0]
fmla v31.4s, v14.4s, v13.s[0]
subs x9, x9, #1
bne L14LoopZ
L14LoopZEnd:
// Drain lanes 1-3 of the final z step; stores of finished accumulators
// are interleaved into the last fmla group.
fmla v18.4s, v15.4s, v0.s[1]
fmla v19.4s, v15.4s, v1.s[1]
fmla v20.4s, v15.4s, v2.s[1]
fmla v21.4s, v15.4s, v3.s[1]
fmla v22.4s, v15.4s, v4.s[1]
fmla v23.4s, v15.4s, v5.s[1]
fmla v24.4s, v15.4s, v6.s[1]
fmla v25.4s, v15.4s, v7.s[1]
fmla v26.4s, v15.4s, v8.s[1]
fmla v27.4s, v15.4s, v9.s[1]
fmla v28.4s, v15.4s, v10.s[1]
fmla v29.4s, v15.4s, v11.s[1]
fmla v30.4s, v15.4s, v12.s[1]
fmla v31.4s, v15.4s, v13.s[1]
fmla v18.4s, v16.4s, v0.s[2]
fmla v19.4s, v16.4s, v1.s[2]
fmla v20.4s, v16.4s, v2.s[2]
fmla v21.4s, v16.4s, v3.s[2]
fmla v22.4s, v16.4s, v4.s[2]
fmla v23.4s, v16.4s, v5.s[2]
fmla v24.4s, v16.4s, v6.s[2]
fmla v25.4s, v16.4s, v7.s[2]
fmla v26.4s, v16.4s, v8.s[2]
fmla v27.4s, v16.4s, v9.s[2]
fmla v28.4s, v16.4s, v10.s[2]
fmla v29.4s, v16.4s, v11.s[2]
fmla v30.4s, v16.4s, v12.s[2]
fmla v31.4s, v16.4s, v13.s[2]
mov x12, x0 // remember this row's dst base
fmla v18.4s, v17.4s, v0.s[3]
fmla v19.4s, v17.4s, v1.s[3]
fmla v20.4s, v17.4s, v2.s[3]
fmla v21.4s, v17.4s, v3.s[3]
fmla v22.4s, v17.4s, v4.s[3]
st1 {v18.4s, v19.4s}, [x0], #32 // cols 0-1
fmla v23.4s, v17.4s, v5.s[3]
fmla v24.4s, v17.4s, v6.s[3]
fmla v25.4s, v17.4s, v7.s[3]
fmla v26.4s, v17.4s, v8.s[3]
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 // cols 2-5
fmla v27.4s, v17.4s, v9.s[3]
fmla v28.4s, v17.4s, v10.s[3]
fmla v29.4s, v17.4s, v11.s[3]
fmla v30.4s, v17.4s, v12.s[3]
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 // cols 6-9
fmla v31.4s, v17.4s, v13.s[3]
// Weights were consumed sequentially (64B per z step via post-inc);
// adding weight_depth_offset skips the inter-plane gap.
add x2, x2, x6
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 // cols 10-13
subs x5, x5, #1
add x0, x12, x4 // next output quad's dst row
bne LoopDz
// Restore callee-saved v8-v15 and return.
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif

View File

@ -19,45 +19,88 @@ asm_function MNNMatrixCopyUnit
//Auto: x0: C, x1:A, x2:cStride
//x3:aStride, x4:height
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov x12, #4 //sizeof(float)
mul x2, x12, x2
mul x3, x12, x3
mov x12, #4 // cache prefetch param
mul x5, x3, x12
subs x4, x4, #1
mov x8, x0
mov x9, x1
ld1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x1], #64
ld1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x1], #64
ld1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x1], #64
cmp x4, #4
blt L1Loop
beq LoopYEnd
L4Loop:
add x9, x1, x5
prfm pldl1keep, [x9]
prfm pldl1keep, [x9, #64]
add x9, x9, x3
prfm pldl1keep, [x9]
prfm pldl1keep, [x9, #64]
add x9, x9, x3
prfm pldl1keep, [x9]
prfm pldl1keep, [x9, #64]
add x9, x9, x3
prfm pldl1keep, [x9]
prfm pldl1keep, [x9, #64]
LoopY:
// Unit = 14 for arm64a
st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x0], #64
st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x0], #64
ld1 {v30.4s, v31.4s}, [x1]
st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x0], #64
st1 {v30.4s, v31.4s}, [x0]
mov x9, x1
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
add x1, x9, x3
mov x9, x1
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
add x1, x9, x3
mov x9, x1
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
add x1, x9, x3
mov x9, x1
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x1], #64
add x1, x9, x3
mov x8, x0
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
add x0, x8, x2
mov x8, x0
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
add x0, x8, x2
mov x8, x0
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
add x0, x8, x2
mov x8, x0
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64
add x0, x8, x2
subs x4, x4, #4
cmp x4, #3
bgt L4Loop
cbz x4, LoopEnd
L1Loop:
mov x9, x1
ld1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x1], #64
ld1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x1], #64
ld1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x1], #64
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
mov x8, x0
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
add x1, x9, x3
add x0, x8, x2
subs x4, x4, #1
bne LoopY
bne L1Loop
LoopYEnd:
st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x0], #64
st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x0], #64
ld1 {v30.4s, v31.4s}, [x1]
st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x0], #64
st1 {v30.4s, v31.4s}, [x0]
LoopEnd:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret

View File

@ -238,7 +238,7 @@ void MNNConvRunForLineint8_t(float* dst, const int8_t* src, const int8_t* weight
}
}
void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
size_t dst_depth_quad, size_t weight_depth_offset) {
MNNGemmFloatCommon_4(dstOrigin, src, weight, src_depth_quad, dst_step, dst_depth_quad, CONVOLUTION_TILED_NUMBER,
weight_depth_offset);

View File

@ -16,11 +16,7 @@
extern "C" {
#endif
#ifdef __aarch64__
#define CONVOLUTION_TILED_NUMBER 14
#else
#define CONVOLUTION_TILED_NUMBER 8
#endif
#define CONV_SETUP_KERNELSIZE(KB) \
int kernel_height = layer->kernelY(); \
@ -95,8 +91,9 @@ void MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* wei
void MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
void MNNGemmFloatUnit_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
void MNNGemmFloatUnit(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
size_t dst_depth_quad, size_t weight_depth_offset);
void MNNGemmFloatOne_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
size_t dst_depth_quad, size_t weight_depth_offset);
void MNNGemmFloatCommon_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,

View File

@ -202,7 +202,7 @@ ErrorCode Convolution3D3x3::onExecute(const std::vector<Tensor*>& inputs, const
const float* _weight = weight + kd * BLOCK_UNIT2 * dc_4 * ic_4 * 16;
for (int i = start; i < end; ++i) {
if (xC == CONVOLUTION_TILED_NUMBER) {
MNNGemmFloatUnit_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
MNNGemmFloatUnit(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
} else {
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,

View File

@ -328,7 +328,7 @@ ErrorCode Convolution3x3::onExecute(const std::vector<Tensor*>& inputs, const st
// Multi
if (xC == CONVOLUTION_TILED_NUMBER) {
for (int i = start; i < end; ++i) {
MNNGemmFloatUnit_4(dstOrigin + i * dc_4 * 4 * xC, srcOrigin + i * ic_4 * 4 * xC,
MNNGemmFloatUnit(dstOrigin + i * dc_4 * 4 * xC, srcOrigin + i * ic_4 * 4 * xC,
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
}
} else {

View File

@ -211,7 +211,7 @@ ErrorCode ConvolutionTiledExecutorBasic::onResize(const std::vector<Tensor*>& in
}
// GEMM
if (xC == CONVOLUTION_TILED_NUMBER) {
MNNGemmFloatUnit_4(dstOrigin + start * 4, colBuffer,
MNNGemmFloatUnit(dstOrigin + start * 4, colBuffer,
weightPtr, icC4 * kernel_width * kernel_height, width * height * 4, ocC4, 0);
} else {
MNNGemmFloatCommon_4(dstOrigin + start * 4, colBuffer,

View File

@ -199,7 +199,7 @@ ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, co
if (xC == CONVOLUTION_TILED_NUMBER) {
for (int i = 0; i < srcUnit2; ++i) {
MNNGemmFloatUnit_4(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC,
MNNGemmFloatUnit(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC,
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
}
} else {

View File

@ -260,7 +260,7 @@ ErrorCode ConvolutionWinograd3D::onExecute(const std::vector<Tensor *> &inputs,
const float* _weight = weight + kd * srcUnit2 * dc_4 * ic_4 * 16;
for (int i = start; i < end; ++i) {
if (xC == CONVOLUTION_TILED_NUMBER) {
MNNGemmFloatUnit_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
MNNGemmFloatUnit(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
} else {
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,

View File

@ -60,7 +60,7 @@ static void _winograd(const DeconvolutionWithStride::ComputeUnit& unit, int thre
auto tempSourceAddr = sourceAddr + i * buffer->stride(2);
auto tempColAddr = destAddr + i * unit.dstBuffer->stride(1);
auto weightAddr = unit.weight->host<float>() + unit.weight->stride(0) * i;
MNNGemmFloatUnit_4(tempColAddr, tempSourceAddr, weightAddr, ic_4, CONVOLUTION_TILED_NUMBER * 4, dc_4, 0);
MNNGemmFloatUnit(tempColAddr, tempSourceAddr, weightAddr, ic_4, CONVOLUTION_TILED_NUMBER * 4, dc_4, 0);
}
auto B = unit.winogradInfo.B.get();
auto midAddr = unit.winogradInfo.dstTransformedBuffer->host<float>() +
@ -96,7 +96,7 @@ static void _gemmAndIm2col(const DeconvolutionWithStride::ComputeUnit& unit, int
for (int dy = 0; dy < gDefaultUnit; ++dy) {
for (int dx = 0; dx < gDefaultUnit; ++dx) {
auto tempSourceAddr = srcTotal + (dx + dy * gDefaultUnit) * srcCount;
MNNGemmFloatUnit_4(tempColAddr, tempSourceAddr, weightAddr, icDiv4, CONVOLUTION_TILED_NUMBER * 4, count, 0);
MNNGemmFloatUnit(tempColAddr, tempSourceAddr, weightAddr, icDiv4, CONVOLUTION_TILED_NUMBER * 4, count, 0);
// FUNC_PRINT_ALL(tempColAddr[0], f);
for (int fy = 0; fy < unit.yUnit; ++fy) {

View File

@ -117,7 +117,7 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(const Tensor* AT, const
int lineCount = CONVOLUTION_TILED_NUMBER * 4;
auto aStart = aHost + xStart * 4;
MNNMatrixCopyUnit(tileHost, aStart, lineCount, aStride, l);
MNNGemmFloatUnit_4(cHost + 4 * xStart, tileHost, bHost, l, cStride, h, bExtraStride);
MNNGemmFloatUnit(cHost + 4 * xStart, tileHost, bHost, l, cStride, h, bExtraStride);
}
if (tId != numberThread -1) {
return;
@ -148,9 +148,9 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(const Tensor* AT, const
}
if (e == CONVOLUTION_TILED_NUMBER) {
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
for (int y=tId; y<h; y+=numberThread) {
MNNGemmFloatUnit_4(cHost + cStride * y, aHost, bHost + bStride * y, l, 0, 1, 0);
}
int yStep = UP_DIV(h, numberThread), yStart = tId * yStep, yNum = ALIMIN(yStart + yStep, h) - yStart;
if (yNum <= 0) return;
MNNGemmFloatUnit(cHost + cStride * yStart, aHost, bHost + bStride * yStart, l, cStride, yNum, 0);
}, numberThread));
} else if (e == 1) {
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
@ -200,7 +200,6 @@ ErrorCode StrassenMatrixComputor::_generateMatMul(const Tensor* AT, const Tensor
if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBER || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {
return _generateTrivalMatMul(AT, BT, CT);
}
// MNN_PRINT("saveCost = %f, e=%d, l=%d, h=%d\n", saveCost, e, l, h);
// Strassen Construct
auto bn = backend();

View File

@ -16,7 +16,7 @@
extern "C" {
#endif
#define MNN_MEMORY_ALIGN_DEFAULT 32
#define MNN_MEMORY_ALIGN_DEFAULT 64
/**
* @brief alloc memory with given size & alignment.

View File

@ -389,10 +389,10 @@ static int test_main(int argc, const char* argv[]) {
auto outputTensor = net->getSessionOutput(session, NULL);
MNN::Tensor expectTensor(outputTensor, outputTensor->getDimensionType());
outputTensor->copyToHostTensor(&expectTensor);
auto outputFile = pwd + "output.txt";
/*auto outputFile = pwd + "output.txt";
if (outputTensor->size() > 0) {
dumpTensor2File(&expectTensor, outputFile.c_str());
}
}*/
// benchmark. for CPU, op time means calc duration; for others, op time means schedule duration.
{
@ -419,6 +419,16 @@ static int test_main(int argc, const char* argv[]) {
};
if (t > 0) {
#define WARMUP
#ifdef WARMUP
// warmup: 10
for (int warmup = 0; warmup < 10; ++warmup) {
inputTensor->copyFromHostTensor(&givenTensor);
net->runSession(session);
outputTensor->copyToHostTensor(&expectTensor);
}
#endif // WARMUP
std::vector<float> times(t, 0.0f);
for (int i = 0; i < t; ++i) {
auto begin = getTimeInUs();