mirror of https://github.com/alibaba/MNN.git
[PATCH 22/28] [MNN:Speed] 8x8 Gemm and cache prefetch optimize
This commit is contained in:
parent
cf2ddd36a0
commit
c0cb82d9ab
|
@ -1,5 +1,5 @@
|
||||||
//
|
//
|
||||||
// MNNGemmFloatUnit_4.S
|
// MNNGemmFloatUnit.S
|
||||||
// MNN
|
// MNN
|
||||||
//
|
//
|
||||||
// Created by MNN on 2019/02/04.
|
// Created by MNN on 2019/02/04.
|
||||||
|
@ -13,8 +13,8 @@
|
||||||
.text
|
.text
|
||||||
.align 5
|
.align 5
|
||||||
|
|
||||||
asm_function MNNGemmFloatUnit_4
|
asm_function MNNGemmFloatUnit
|
||||||
//void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
|
//void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
|
||||||
|
|
||||||
//Auto:
|
//Auto:
|
||||||
//r0:dstOrigin, r1:src, r2: weight, r3:src_depth_quad
|
//r0:dstOrigin, r1:src, r2: weight, r3:src_depth_quad
|
|
@ -0,0 +1,282 @@
|
||||||
|
//
|
||||||
|
// MNNGemmFloatUnit.S
|
||||||
|
// MNN
|
||||||
|
//
|
||||||
|
// Created by MNN on 2019/02/04.
|
||||||
|
// Copyright © 2018, Alibaba Group Holding Limited
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifdef __aarch64__
|
||||||
|
|
||||||
|
#include "MNNAsmGlobal.h"
|
||||||
|
|
||||||
|
.text
|
||||||
|
.align 5
|
||||||
|
|
||||||
|
asm_function MNNGemmFloatUnit
//void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
//
// 8-column float GEMM tile kernel.
// For every output slice dz and each of the 8 source columns j:
//   dst[dz][j] (float4) = sum over sz < src_depth_quad, k < 4 of
//                         weight[dz][sz][k] (float4) * src[sz][j].s[k]
// Layout, as addressed by the code below:
//   src    : 8 float4 (128 bytes) per sz step, contiguous.
//   weight : 4 float4 (64 bytes) per sz step; consecutive dz slices are
//            src_depth_quad * 64 + weight_depth_offset * 4 bytes apart.
//   dst    : 8 float4 (128 bytes) per dz; consecutive dz slices are
//            dst_step floats apart.
// Two dz slices are computed per pass of LoopDz (accumulators v16-v31);
// an odd trailing slice is handled by LoopDzExtra (accumulators v16-v23).
// Assumes src_depth_quad >= 1: the sz counters are decremented before they
// are tested, so a zero count would wrap instead of skipping the loop.

//Auto
//x0: dst, x1:src, x2:weight, x3:src_depth_quad

//x4:dst_step, x5:dst_depth_quad, x6: weight_depth_offset

mov x12, #4 //sizeof(float)
mul x4, x12, x4 // dst_step in bytes
mul x6, x12, x6 // weight_depth_offset in bytes
add x11, x6, x3, LSL #6 // byte stride between two dz weight slices

// Save callee-saved v8-v15. sp stays lowered for the whole function so the
// save area is never below the stack pointer; x13 is only a store cursor.
sub sp, sp, #128
mov x13, sp
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x13]

// Nothing to compute (and nothing may be written) for zero output slices.
cbz x5, LoopDzEnd

cmp x5, #2
blt LoopDzExtra

LoopDz:
mov x8, x1
subs x9, x3, #1 // flags are consumed by the beq below; nothing in between sets flags

// Prologue of the software pipeline: load sz = 0 and issue the k = 0 products.
// v8-v11 : weights of slice dz, v12-v15 : weights of slice dz + 1.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2]
add x2, x2, x11
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmul v16.4s, v8.4s, v0.s[0]
fmul v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmul v18.4s, v8.4s, v2.s[0]
fmul v19.4s, v8.4s, v3.s[0]
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
sub x2, x2, x11 // back to slice dz, now advanced by one sz step (64 bytes)
fmul v20.4s, v8.4s, v4.s[0]
fmul v21.4s, v8.4s, v5.s[0]
fmul v22.4s, v8.4s, v6.s[0]
fmul v23.4s, v8.4s, v7.s[0]
fmul v24.4s, v12.4s, v0.s[0]
fmul v25.4s, v12.4s, v1.s[0]
fmul v26.4s, v12.4s, v2.s[0]
fmul v27.4s, v12.4s, v3.s[0]
fmul v28.4s, v12.4s, v4.s[0]
fmul v29.4s, v12.4s, v5.s[0]
fmul v30.4s, v12.4s, v6.s[0]
fmul v31.4s, v12.4s, v7.s[0]

beq L8LoopZEnd
L8LoopZ:
// Prefetch upcoming weights of both slices and upcoming source data.
add x2, x2, #128
prfm pldl1keep, [x2]
prfm pldl1keep, [x2, x11]
sub x2, x2, #128
prfm pldl1keep, [x8, #128]
prfm pldl1keep, [x8, #192]

// k = 1
fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]
fmla v24.4s, v13.4s, v0.s[1]
fmla v25.4s, v13.4s, v1.s[1]
fmla v26.4s, v13.4s, v2.s[1]
fmla v27.4s, v13.4s, v3.s[1]
fmla v28.4s, v13.4s, v4.s[1]
fmla v29.4s, v13.4s, v5.s[1]
fmla v30.4s, v13.4s, v6.s[1]
fmla v31.4s, v13.4s, v7.s[1]

// k = 2
fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]
fmla v24.4s, v14.4s, v0.s[2]
fmla v25.4s, v14.4s, v1.s[2]
fmla v26.4s, v14.4s, v2.s[2]
fmla v27.4s, v14.4s, v3.s[2]
fmla v28.4s, v14.4s, v4.s[2]
fmla v29.4s, v14.4s, v5.s[2]
fmla v30.4s, v14.4s, v6.s[2]
fmla v31.4s, v14.4s, v7.s[2]

// k = 3
fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]
fmla v24.4s, v15.4s, v0.s[3]
fmla v25.4s, v15.4s, v1.s[3]
fmla v26.4s, v15.4s, v2.s[3]
fmla v27.4s, v15.4s, v3.s[3]
fmla v28.4s, v15.4s, v4.s[3]
fmla v29.4s, v15.4s, v5.s[3]
fmla v30.4s, v15.4s, v6.s[3]
fmla v31.4s, v15.4s, v7.s[3]

// Load the next sz step and issue its k = 0 products.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2]
add x2, x2, x11
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v18.4s, v8.4s, v2.s[0]
fmla v19.4s, v8.4s, v3.s[0]
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
sub x2, x2, x11
fmla v20.4s, v8.4s, v4.s[0]
fmla v21.4s, v8.4s, v5.s[0]
fmla v22.4s, v8.4s, v6.s[0]
fmla v23.4s, v8.4s, v7.s[0]
fmla v24.4s, v12.4s, v0.s[0]
fmla v25.4s, v12.4s, v1.s[0]
fmla v26.4s, v12.4s, v2.s[0]
fmla v27.4s, v12.4s, v3.s[0]
fmla v28.4s, v12.4s, v4.s[0]
fmla v29.4s, v12.4s, v5.s[0]
fmla v30.4s, v12.4s, v6.s[0]
fmla v31.4s, v12.4s, v7.s[0]

subs x9, x9, #1
bne L8LoopZ

L8LoopZEnd:
// Drain the pipeline: k = 1..3 of the last loaded sz step, then store.
fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]
fmla v24.4s, v13.4s, v0.s[1]
fmla v25.4s, v13.4s, v1.s[1]
fmla v26.4s, v13.4s, v2.s[1]
fmla v27.4s, v13.4s, v3.s[1]
fmla v28.4s, v13.4s, v4.s[1]
fmla v29.4s, v13.4s, v5.s[1]
fmla v30.4s, v13.4s, v6.s[1]
fmla v31.4s, v13.4s, v7.s[1]

fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]
fmla v24.4s, v14.4s, v0.s[2]
fmla v25.4s, v14.4s, v1.s[2]
fmla v26.4s, v14.4s, v2.s[2]
fmla v27.4s, v14.4s, v3.s[2]
fmla v28.4s, v14.4s, v4.s[2]
fmla v29.4s, v14.4s, v5.s[2]
fmla v30.4s, v14.4s, v6.s[2]
fmla v31.4s, v14.4s, v7.s[2]

mov x12, x0 // remember the start of slice dz; x0 is advanced by the stores

fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]
fmla v24.4s, v15.4s, v0.s[3]
fmla v25.4s, v15.4s, v1.s[3]
fmla v26.4s, v15.4s, v2.s[3]
fmla v27.4s, v15.4s, v3.s[3]
fmla v28.4s, v15.4s, v4.s[3]
// Stores are interleaved with the remaining multiply-accumulates.
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
fmla v29.4s, v15.4s, v5.s[3]
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
fmla v30.4s, v15.4s, v6.s[3]
add x0, x12, x4 // slice dz + 1
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
add x2, x2, x11 // skip the weights of slice dz + 1 ...
fmla v31.4s, v15.4s, v7.s[3]
add x2, x2, x6 // ... and the extra offset after slice dz
sub x5, x5, #2
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64
add x0, x12, x4, LSL #1 // slice dz + 2

cmp x5, #1
blt LoopDzEnd
bgt LoopDz
// x5 == 1 falls through to the single-slice remainder.

LoopDzExtra:

mov x8, x1
mov x9, x3
movi v16.4s, #0
movi v17.4s, #0
movi v18.4s, #0
movi v19.4s, #0
movi v20.4s, #0
movi v21.4s, #0
movi v22.4s, #0
movi v23.4s, #0

L4LoopZ:
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v8.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v18.4s, v8.4s, v2.s[0]
fmla v19.4s, v8.4s, v3.s[0]
fmla v20.4s, v8.4s, v4.s[0]
fmla v21.4s, v8.4s, v5.s[0]
fmla v22.4s, v8.4s, v6.s[0]
fmla v23.4s, v8.4s, v7.s[0]

fmla v16.4s, v9.4s, v0.s[1]
fmla v17.4s, v9.4s, v1.s[1]
fmla v18.4s, v9.4s, v2.s[1]
fmla v19.4s, v9.4s, v3.s[1]
fmla v20.4s, v9.4s, v4.s[1]
fmla v21.4s, v9.4s, v5.s[1]
fmla v22.4s, v9.4s, v6.s[1]
fmla v23.4s, v9.4s, v7.s[1]

fmla v16.4s, v10.4s, v0.s[2]
fmla v17.4s, v10.4s, v1.s[2]
fmla v18.4s, v10.4s, v2.s[2]
fmla v19.4s, v10.4s, v3.s[2]
fmla v20.4s, v10.4s, v4.s[2]
fmla v21.4s, v10.4s, v5.s[2]
fmla v22.4s, v10.4s, v6.s[2]
fmla v23.4s, v10.4s, v7.s[2]

fmla v16.4s, v11.4s, v0.s[3]
fmla v17.4s, v11.4s, v1.s[3]
fmla v18.4s, v11.4s, v2.s[3]
fmla v19.4s, v11.4s, v3.s[3]
fmla v20.4s, v11.4s, v4.s[3]
fmla v21.4s, v11.4s, v5.s[3]
fmla v22.4s, v11.4s, v6.s[3]
fmla v23.4s, v11.4s, v7.s[3]

subs x9, x9, #1
bne L4LoopZ

st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64

LoopDzEnd:
// Restore v8-v15; the two post-indexed loads also return sp to its entry value.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

ret
|
||||||
|
#endif
|
|
@ -1,192 +0,0 @@
|
||||||
//
|
|
||||||
// MNNGemmFloatUnit_4.S
|
|
||||||
// MNN
|
|
||||||
//
|
|
||||||
// Created by MNN on 2019/02/04.
|
|
||||||
// Copyright © 2018, Alibaba Group Holding Limited
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifdef __aarch64__
|
|
||||||
|
|
||||||
#include "MNNAsmGlobal.h"
|
|
||||||
|
|
||||||
.text
|
|
||||||
.align 5
|
|
||||||
|
|
||||||
asm_function MNNGemmFloatUnit_4
//void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t weight_depth_offset)
//
// 14-column float GEMM tile kernel.
// For every output slice dz and each of the 14 source columns j:
//   dst[dz][j] (float4) = sum over sz < src_depth_quad, k < 4 of
//                         weight[dz][sz][k] (float4) * src[sz][j].s[k]
// Layout, as addressed by the code below:
//   src    : 14 float4 (224 bytes) per sz step, contiguous.
//   weight : 4 float4 (64 bytes) per sz step, read sequentially;
//            weight_depth_offset floats are skipped after each dz slice.
//   dst    : 14 float4 (224 bytes) per dz; consecutive dz slices are
//            dst_step floats apart.
// Assumes src_depth_quad >= 1 and dst_depth_quad >= 1: both counters are
// decremented before they are tested.

//Auto
//x0: dst, x1:src, x2:weight, x3:src_depth_quad

//x4:dst_step, x5:dst_depth_quad, x6: weight_depth_offset

mov x12, #4//sizeof(float)
mul x4, x12, x4 // dst_step in bytes
mul x6, x12, x6 // weight_depth_offset in bytes

// Save callee-saved v8-v15. sp stays lowered for the whole function so the
// save area is never below the stack pointer; x13 is only a store cursor.
sub sp, sp, #128
mov x13, sp
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x13]

LoopDz:
mov x8, x1
subs x9, x3, #1 // flags are consumed by the beq below; nothing in between sets flags

// Prologue of the software pipeline: v14-v17 hold the 4 weight vectors,
// v0-v13 the 14 source columns, v18-v31 the 14 accumulators.
ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x2], #64

ld1 {v0.4s, v1.4s}, [x8], #32
fmul v18.4s, v14.4s, v0.s[0]
ld1 {v2.4s, v3.4s}, [x8], #32
fmul v19.4s, v14.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmul v20.4s, v14.4s, v2.s[0]
fmul v21.4s, v14.4s, v3.s[0]
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x8], #64
fmul v22.4s, v14.4s, v4.s[0]
fmul v23.4s, v14.4s, v5.s[0]
ld1 {v12.4s, v13.4s}, [x8], #32
fmul v24.4s, v14.4s, v6.s[0]
fmul v25.4s, v14.4s, v7.s[0]
fmul v26.4s, v14.4s, v8.s[0]
fmul v27.4s, v14.4s, v9.s[0]
fmul v28.4s, v14.4s, v10.s[0]
fmul v29.4s, v14.4s, v11.s[0]
fmul v30.4s, v14.4s, v12.s[0]
fmul v31.4s, v14.4s, v13.s[0]

beq L14LoopZEnd
L14LoopZ:
// k = 1
fmla v18.4s, v15.4s, v0.s[1]
fmla v19.4s, v15.4s, v1.s[1]
fmla v20.4s, v15.4s, v2.s[1]
fmla v21.4s, v15.4s, v3.s[1]
fmla v22.4s, v15.4s, v4.s[1]
fmla v23.4s, v15.4s, v5.s[1]
fmla v24.4s, v15.4s, v6.s[1]
fmla v25.4s, v15.4s, v7.s[1]
fmla v26.4s, v15.4s, v8.s[1]
fmla v27.4s, v15.4s, v9.s[1]
fmla v28.4s, v15.4s, v10.s[1]
fmla v29.4s, v15.4s, v11.s[1]
fmla v30.4s, v15.4s, v12.s[1]
fmla v31.4s, v15.4s, v13.s[1]

// k = 2
fmla v18.4s, v16.4s, v0.s[2]
fmla v19.4s, v16.4s, v1.s[2]
fmla v20.4s, v16.4s, v2.s[2]
fmla v21.4s, v16.4s, v3.s[2]
fmla v22.4s, v16.4s, v4.s[2]
fmla v23.4s, v16.4s, v5.s[2]
fmla v24.4s, v16.4s, v6.s[2]
fmla v25.4s, v16.4s, v7.s[2]
fmla v26.4s, v16.4s, v8.s[2]
fmla v27.4s, v16.4s, v9.s[2]
fmla v28.4s, v16.4s, v10.s[2]
fmla v29.4s, v16.4s, v11.s[2]
fmla v30.4s, v16.4s, v12.s[2]
fmla v31.4s, v16.4s, v13.s[2]

// k = 3
fmla v18.4s, v17.4s, v0.s[3]
fmla v19.4s, v17.4s, v1.s[3]
fmla v20.4s, v17.4s, v2.s[3]
fmla v21.4s, v17.4s, v3.s[3]
fmla v22.4s, v17.4s, v4.s[3]
fmla v23.4s, v17.4s, v5.s[3]
fmla v24.4s, v17.4s, v6.s[3]
fmla v25.4s, v17.4s, v7.s[3]
fmla v26.4s, v17.4s, v8.s[3]
fmla v27.4s, v17.4s, v9.s[3]
fmla v28.4s, v17.4s, v10.s[3]
fmla v29.4s, v17.4s, v11.s[3]
fmla v30.4s, v17.4s, v12.s[3]
fmla v31.4s, v17.4s, v13.s[3]

// Load the next sz step and issue its k = 0 products.
ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x2], #64

ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64
fmla v18.4s, v14.4s, v0.s[0]
fmla v19.4s, v14.4s, v1.s[0]
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64
fmla v20.4s, v14.4s, v2.s[0]
fmla v21.4s, v14.4s, v3.s[0]
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x8], #64
fmla v22.4s, v14.4s, v4.s[0]
fmla v23.4s, v14.4s, v5.s[0]
ld1 {v12.4s, v13.4s}, [x8], #32
fmla v24.4s, v14.4s, v6.s[0]
fmla v25.4s, v14.4s, v7.s[0]
fmla v26.4s, v14.4s, v8.s[0]
fmla v27.4s, v14.4s, v9.s[0]
fmla v28.4s, v14.4s, v10.s[0]
fmla v29.4s, v14.4s, v11.s[0]
fmla v30.4s, v14.4s, v12.s[0]
fmla v31.4s, v14.4s, v13.s[0]

subs x9, x9, #1
bne L14LoopZ

L14LoopZEnd:
// Drain the pipeline: k = 1..3 of the last loaded sz step, then store.
fmla v18.4s, v15.4s, v0.s[1]
fmla v19.4s, v15.4s, v1.s[1]
fmla v20.4s, v15.4s, v2.s[1]
fmla v21.4s, v15.4s, v3.s[1]
fmla v22.4s, v15.4s, v4.s[1]
fmla v23.4s, v15.4s, v5.s[1]
fmla v24.4s, v15.4s, v6.s[1]
fmla v25.4s, v15.4s, v7.s[1]
fmla v26.4s, v15.4s, v8.s[1]
fmla v27.4s, v15.4s, v9.s[1]
fmla v28.4s, v15.4s, v10.s[1]
fmla v29.4s, v15.4s, v11.s[1]
fmla v30.4s, v15.4s, v12.s[1]
fmla v31.4s, v15.4s, v13.s[1]

fmla v18.4s, v16.4s, v0.s[2]
fmla v19.4s, v16.4s, v1.s[2]
fmla v20.4s, v16.4s, v2.s[2]
fmla v21.4s, v16.4s, v3.s[2]
fmla v22.4s, v16.4s, v4.s[2]
fmla v23.4s, v16.4s, v5.s[2]
fmla v24.4s, v16.4s, v6.s[2]
fmla v25.4s, v16.4s, v7.s[2]
fmla v26.4s, v16.4s, v8.s[2]
fmla v27.4s, v16.4s, v9.s[2]
fmla v28.4s, v16.4s, v10.s[2]
fmla v29.4s, v16.4s, v11.s[2]
fmla v30.4s, v16.4s, v12.s[2]
fmla v31.4s, v16.4s, v13.s[2]

mov x12, x0 // remember the start of slice dz; x0 is advanced by the stores

// Stores are interleaved with the remaining multiply-accumulates.
fmla v18.4s, v17.4s, v0.s[3]
fmla v19.4s, v17.4s, v1.s[3]
fmla v20.4s, v17.4s, v2.s[3]
fmla v21.4s, v17.4s, v3.s[3]
fmla v22.4s, v17.4s, v4.s[3]
st1 {v18.4s, v19.4s}, [x0], #32
fmla v23.4s, v17.4s, v5.s[3]
fmla v24.4s, v17.4s, v6.s[3]
fmla v25.4s, v17.4s, v7.s[3]
fmla v26.4s, v17.4s, v8.s[3]
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
fmla v27.4s, v17.4s, v9.s[3]
fmla v28.4s, v17.4s, v10.s[3]
fmla v29.4s, v17.4s, v11.s[3]
fmla v30.4s, v17.4s, v12.s[3]
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
fmla v31.4s, v17.4s, v13.s[3]
add x2, x2, x6 // skip the extra weight offset after this dz slice

st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64

subs x5, x5, #1
add x0, x12, x4 // next dz slice; add does not touch the flags set by subs

bne LoopDz
// Restore v8-v15; the two post-indexed loads also return sp to its entry value.
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

ret
|
|
||||||
#endif
|
|
|
@ -19,45 +19,88 @@ asm_function MNNMatrixCopyUnit
|
||||||
//Auto: x0: C, x1:A, x2:cStride
|
//Auto: x0: C, x1:A, x2:cStride
|
||||||
//x3:aStride, x4:height
|
//x3:aStride, x4:height
|
||||||
|
|
||||||
|
sub sp, sp, #128
|
||||||
|
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||||
|
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||||
|
|
||||||
mov x12, #4 //sizeof(float)
|
mov x12, #4 //sizeof(float)
|
||||||
mul x2, x12, x2
|
mul x2, x12, x2
|
||||||
mul x3, x12, x3
|
mul x3, x12, x3
|
||||||
|
mov x12, #4 // cache prefetch param
|
||||||
|
mul x5, x3, x12
|
||||||
|
|
||||||
|
cmp x4, #4
|
||||||
|
blt L1Loop
|
||||||
|
|
||||||
|
L4Loop:
|
||||||
|
add x9, x1, x5
|
||||||
|
prfm pldl1keep, [x9]
|
||||||
|
prfm pldl1keep, [x9, #64]
|
||||||
|
add x9, x9, x3
|
||||||
|
prfm pldl1keep, [x9]
|
||||||
|
prfm pldl1keep, [x9, #64]
|
||||||
|
add x9, x9, x3
|
||||||
|
prfm pldl1keep, [x9]
|
||||||
|
prfm pldl1keep, [x9, #64]
|
||||||
|
add x9, x9, x3
|
||||||
|
prfm pldl1keep, [x9]
|
||||||
|
prfm pldl1keep, [x9, #64]
|
||||||
|
|
||||||
subs x4, x4, #1
|
|
||||||
mov x8, x0
|
|
||||||
mov x9, x1
|
mov x9, x1
|
||||||
ld1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x1], #64
|
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
|
||||||
ld1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x1], #64
|
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
|
||||||
ld1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x1], #64
|
|
||||||
|
|
||||||
beq LoopYEnd
|
|
||||||
|
|
||||||
LoopY:
|
|
||||||
// Unit = 14 for arm64a
|
|
||||||
st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x0], #64
|
|
||||||
st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x0], #64
|
|
||||||
ld1 {v30.4s, v31.4s}, [x1]
|
|
||||||
st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x0], #64
|
|
||||||
st1 {v30.4s, v31.4s}, [x0]
|
|
||||||
add x1, x9, x3
|
add x1, x9, x3
|
||||||
|
mov x9, x1
|
||||||
|
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
|
||||||
|
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
|
||||||
|
add x1, x9, x3
|
||||||
|
mov x9, x1
|
||||||
|
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
|
||||||
|
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
|
||||||
|
add x1, x9, x3
|
||||||
|
mov x9, x1
|
||||||
|
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64
|
||||||
|
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x1], #64
|
||||||
|
add x1, x9, x3
|
||||||
|
mov x8, x0
|
||||||
|
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
|
||||||
|
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
|
||||||
add x0, x8, x2
|
add x0, x8, x2
|
||||||
mov x8, x0
|
mov x8, x0
|
||||||
|
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
|
||||||
|
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
|
||||||
|
add x0, x8, x2
|
||||||
|
mov x8, x0
|
||||||
|
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
|
||||||
|
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
|
||||||
|
add x0, x8, x2
|
||||||
|
mov x8, x0
|
||||||
|
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
|
||||||
|
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64
|
||||||
|
add x0, x8, x2
|
||||||
|
subs x4, x4, #4
|
||||||
|
cmp x4, #3
|
||||||
|
bgt L4Loop
|
||||||
|
|
||||||
|
cbz x4, LoopEnd
|
||||||
|
|
||||||
|
L1Loop:
|
||||||
mov x9, x1
|
mov x9, x1
|
||||||
ld1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x1], #64
|
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
|
||||||
ld1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x1], #64
|
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
|
||||||
ld1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x1], #64
|
mov x8, x0
|
||||||
|
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
|
||||||
|
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
|
||||||
|
add x1, x9, x3
|
||||||
|
add x0, x8, x2
|
||||||
subs x4, x4, #1
|
subs x4, x4, #1
|
||||||
bne LoopY
|
bne L1Loop
|
||||||
|
|
||||||
LoopYEnd:
|
LoopEnd:
|
||||||
|
|
||||||
st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x0], #64
|
|
||||||
st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [x0], #64
|
|
||||||
ld1 {v30.4s, v31.4s}, [x1]
|
|
||||||
st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [x0], #64
|
|
||||||
st1 {v30.4s, v31.4s}, [x0]
|
|
||||||
|
|
||||||
|
sub sp, sp, #128
|
||||||
|
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||||
|
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||||
|
|
||||||
ret
|
ret
|
||||||
|
|
||||||
|
|
|
@ -238,7 +238,7 @@ void MNNConvRunForLineint8_t(float* dst, const int8_t* src, const int8_t* weight
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MNNGemmFloatUnit_4(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
void MNNGemmFloatUnit(float* dstOrigin, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
||||||
size_t dst_depth_quad, size_t weight_depth_offset) {
|
size_t dst_depth_quad, size_t weight_depth_offset) {
|
||||||
MNNGemmFloatCommon_4(dstOrigin, src, weight, src_depth_quad, dst_step, dst_depth_quad, CONVOLUTION_TILED_NUMBER,
|
MNNGemmFloatCommon_4(dstOrigin, src, weight, src_depth_quad, dst_step, dst_depth_quad, CONVOLUTION_TILED_NUMBER,
|
||||||
weight_depth_offset);
|
weight_depth_offset);
|
||||||
|
|
|
@ -16,11 +16,7 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __aarch64__
|
|
||||||
#define CONVOLUTION_TILED_NUMBER 14
|
|
||||||
#else
|
|
||||||
#define CONVOLUTION_TILED_NUMBER 8
|
#define CONVOLUTION_TILED_NUMBER 8
|
||||||
#endif
|
|
||||||
|
|
||||||
#define CONV_SETUP_KERNELSIZE(KB) \
|
#define CONV_SETUP_KERNELSIZE(KB) \
|
||||||
int kernel_height = layer->kernelY(); \
|
int kernel_height = layer->kernelY(); \
|
||||||
|
@ -95,8 +91,9 @@ void MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* wei
|
||||||
void MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
|
void MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
|
||||||
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
|
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
|
||||||
|
|
||||||
void MNNGemmFloatUnit_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
void MNNGemmFloatUnit(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
||||||
size_t dst_depth_quad, size_t weight_depth_offset);
|
size_t dst_depth_quad, size_t weight_depth_offset);
|
||||||
|
|
||||||
void MNNGemmFloatOne_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
void MNNGemmFloatOne_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
||||||
size_t dst_depth_quad, size_t weight_depth_offset);
|
size_t dst_depth_quad, size_t weight_depth_offset);
|
||||||
void MNNGemmFloatCommon_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
void MNNGemmFloatCommon_4(float* dst, const float* src, const float* weight, size_t src_depth_quad, size_t dst_step,
|
||||||
|
|
|
@ -202,7 +202,7 @@ ErrorCode Convolution3D3x3::onExecute(const std::vector<Tensor*>& inputs, const
|
||||||
const float* _weight = weight + kd * BLOCK_UNIT2 * dc_4 * ic_4 * 16;
|
const float* _weight = weight + kd * BLOCK_UNIT2 * dc_4 * ic_4 * 16;
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
if (xC == CONVOLUTION_TILED_NUMBER) {
|
if (xC == CONVOLUTION_TILED_NUMBER) {
|
||||||
MNNGemmFloatUnit_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatUnit(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
||||||
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
||||||
} else {
|
} else {
|
||||||
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
||||||
|
|
|
@ -328,7 +328,7 @@ ErrorCode Convolution3x3::onExecute(const std::vector<Tensor*>& inputs, const st
|
||||||
// Multi
|
// Multi
|
||||||
if (xC == CONVOLUTION_TILED_NUMBER) {
|
if (xC == CONVOLUTION_TILED_NUMBER) {
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
MNNGemmFloatUnit_4(dstOrigin + i * dc_4 * 4 * xC, srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatUnit(dstOrigin + i * dc_4 * 4 * xC, srcOrigin + i * ic_4 * 4 * xC,
|
||||||
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -211,7 +211,7 @@ ErrorCode ConvolutionTiledExecutorBasic::onResize(const std::vector<Tensor*>& in
|
||||||
}
|
}
|
||||||
// GEMM
|
// GEMM
|
||||||
if (xC == CONVOLUTION_TILED_NUMBER) {
|
if (xC == CONVOLUTION_TILED_NUMBER) {
|
||||||
MNNGemmFloatUnit_4(dstOrigin + start * 4, colBuffer,
|
MNNGemmFloatUnit(dstOrigin + start * 4, colBuffer,
|
||||||
weightPtr, icC4 * kernel_width * kernel_height, width * height * 4, ocC4, 0);
|
weightPtr, icC4 * kernel_width * kernel_height, width * height * 4, ocC4, 0);
|
||||||
} else {
|
} else {
|
||||||
MNNGemmFloatCommon_4(dstOrigin + start * 4, colBuffer,
|
MNNGemmFloatCommon_4(dstOrigin + start * 4, colBuffer,
|
||||||
|
|
|
@ -199,7 +199,7 @@ ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, co
|
||||||
|
|
||||||
if (xC == CONVOLUTION_TILED_NUMBER) {
|
if (xC == CONVOLUTION_TILED_NUMBER) {
|
||||||
for (int i = 0; i < srcUnit2; ++i) {
|
for (int i = 0; i < srcUnit2; ++i) {
|
||||||
MNNGemmFloatUnit_4(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatUnit(_dstOrigin + i * dc_4 * 4 * xC, _srcOrigin + i * ic_4 * 4 * xC,
|
||||||
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -260,7 +260,7 @@ ErrorCode ConvolutionWinograd3D::onExecute(const std::vector<Tensor *> &inputs,
|
||||||
const float* _weight = weight + kd * srcUnit2 * dc_4 * ic_4 * 16;
|
const float* _weight = weight + kd * srcUnit2 * dc_4 * ic_4 * 16;
|
||||||
for (int i = start; i < end; ++i) {
|
for (int i = start; i < end; ++i) {
|
||||||
if (xC == CONVOLUTION_TILED_NUMBER) {
|
if (xC == CONVOLUTION_TILED_NUMBER) {
|
||||||
MNNGemmFloatUnit_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatUnit(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
||||||
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
_weight + i * 16 * ic_4 * dc_4, ic_4, xC * 4, dc_4, 0);
|
||||||
} else {
|
} else {
|
||||||
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
MNNGemmFloatCommon_4(tempDst + i * dc_4 * xC * 4, _srcOrigin + i * ic_4 * 4 * xC,
|
||||||
|
|
|
@ -60,7 +60,7 @@ static void _winograd(const DeconvolutionWithStride::ComputeUnit& unit, int thre
|
||||||
auto tempSourceAddr = sourceAddr + i * buffer->stride(2);
|
auto tempSourceAddr = sourceAddr + i * buffer->stride(2);
|
||||||
auto tempColAddr = destAddr + i * unit.dstBuffer->stride(1);
|
auto tempColAddr = destAddr + i * unit.dstBuffer->stride(1);
|
||||||
auto weightAddr = unit.weight->host<float>() + unit.weight->stride(0) * i;
|
auto weightAddr = unit.weight->host<float>() + unit.weight->stride(0) * i;
|
||||||
MNNGemmFloatUnit_4(tempColAddr, tempSourceAddr, weightAddr, ic_4, CONVOLUTION_TILED_NUMBER * 4, dc_4, 0);
|
MNNGemmFloatUnit(tempColAddr, tempSourceAddr, weightAddr, ic_4, CONVOLUTION_TILED_NUMBER * 4, dc_4, 0);
|
||||||
}
|
}
|
||||||
auto B = unit.winogradInfo.B.get();
|
auto B = unit.winogradInfo.B.get();
|
||||||
auto midAddr = unit.winogradInfo.dstTransformedBuffer->host<float>() +
|
auto midAddr = unit.winogradInfo.dstTransformedBuffer->host<float>() +
|
||||||
|
@ -96,7 +96,7 @@ static void _gemmAndIm2col(const DeconvolutionWithStride::ComputeUnit& unit, int
|
||||||
for (int dy = 0; dy < gDefaultUnit; ++dy) {
|
for (int dy = 0; dy < gDefaultUnit; ++dy) {
|
||||||
for (int dx = 0; dx < gDefaultUnit; ++dx) {
|
for (int dx = 0; dx < gDefaultUnit; ++dx) {
|
||||||
auto tempSourceAddr = srcTotal + (dx + dy * gDefaultUnit) * srcCount;
|
auto tempSourceAddr = srcTotal + (dx + dy * gDefaultUnit) * srcCount;
|
||||||
MNNGemmFloatUnit_4(tempColAddr, tempSourceAddr, weightAddr, icDiv4, CONVOLUTION_TILED_NUMBER * 4, count, 0);
|
MNNGemmFloatUnit(tempColAddr, tempSourceAddr, weightAddr, icDiv4, CONVOLUTION_TILED_NUMBER * 4, count, 0);
|
||||||
// FUNC_PRINT_ALL(tempColAddr[0], f);
|
// FUNC_PRINT_ALL(tempColAddr[0], f);
|
||||||
|
|
||||||
for (int fy = 0; fy < unit.yUnit; ++fy) {
|
for (int fy = 0; fy < unit.yUnit; ++fy) {
|
||||||
|
|
|
@ -117,7 +117,7 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(const Tensor* AT, const
|
||||||
int lineCount = CONVOLUTION_TILED_NUMBER * 4;
|
int lineCount = CONVOLUTION_TILED_NUMBER * 4;
|
||||||
auto aStart = aHost + xStart * 4;
|
auto aStart = aHost + xStart * 4;
|
||||||
MNNMatrixCopyUnit(tileHost, aStart, lineCount, aStride, l);
|
MNNMatrixCopyUnit(tileHost, aStart, lineCount, aStride, l);
|
||||||
MNNGemmFloatUnit_4(cHost + 4 * xStart, tileHost, bHost, l, cStride, h, bExtraStride);
|
MNNGemmFloatUnit(cHost + 4 * xStart, tileHost, bHost, l, cStride, h, bExtraStride);
|
||||||
}
|
}
|
||||||
if (tId != numberThread -1) {
|
if (tId != numberThread -1) {
|
||||||
return;
|
return;
|
||||||
|
@ -148,9 +148,9 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(const Tensor* AT, const
|
||||||
}
|
}
|
||||||
if (e == CONVOLUTION_TILED_NUMBER) {
|
if (e == CONVOLUTION_TILED_NUMBER) {
|
||||||
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
|
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
|
||||||
for (int y=tId; y<h; y+=numberThread) {
|
int yStep = UP_DIV(h, numberThread), yStart = tId * yStep, yNum = ALIMIN(yStart + yStep, h) - yStart;
|
||||||
MNNGemmFloatUnit_4(cHost + cStride * y, aHost, bHost + bStride * y, l, 0, 1, 0);
|
if (yNum <= 0) return;
|
||||||
}
|
MNNGemmFloatUnit(cHost + cStride * yStart, aHost, bHost + bStride * yStart, l, cStride, yNum, 0);
|
||||||
}, numberThread));
|
}, numberThread));
|
||||||
} else if (e == 1) {
|
} else if (e == 1) {
|
||||||
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
|
mFunctions.emplace_back(std::make_pair([aHost, bHost, cHost, l, h, cStride, bStride, numberThread](int tId) {
|
||||||
|
@ -200,7 +200,6 @@ ErrorCode StrassenMatrixComputor::_generateMatMul(const Tensor* AT, const Tensor
|
||||||
if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBER || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {
|
if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBER || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {
|
||||||
return _generateTrivalMatMul(AT, BT, CT);
|
return _generateTrivalMatMul(AT, BT, CT);
|
||||||
}
|
}
|
||||||
// MNN_PRINT("saveCost = %f, e=%d, l=%d, h=%d\n", saveCost, e, l, h);
|
|
||||||
|
|
||||||
// Strassen Construct
|
// Strassen Construct
|
||||||
auto bn = backend();
|
auto bn = backend();
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define MNN_MEMORY_ALIGN_DEFAULT 32
|
#define MNN_MEMORY_ALIGN_DEFAULT 64
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief alloc memory with given size & alignment.
|
* @brief alloc memory with given size & alignment.
|
||||||
|
|
|
@ -389,10 +389,10 @@ static int test_main(int argc, const char* argv[]) {
|
||||||
auto outputTensor = net->getSessionOutput(session, NULL);
|
auto outputTensor = net->getSessionOutput(session, NULL);
|
||||||
MNN::Tensor expectTensor(outputTensor, outputTensor->getDimensionType());
|
MNN::Tensor expectTensor(outputTensor, outputTensor->getDimensionType());
|
||||||
outputTensor->copyToHostTensor(&expectTensor);
|
outputTensor->copyToHostTensor(&expectTensor);
|
||||||
auto outputFile = pwd + "output.txt";
|
/*auto outputFile = pwd + "output.txt";
|
||||||
if (outputTensor->size() > 0) {
|
if (outputTensor->size() > 0) {
|
||||||
dumpTensor2File(&expectTensor, outputFile.c_str());
|
dumpTensor2File(&expectTensor, outputFile.c_str());
|
||||||
}
|
}*/
|
||||||
|
|
||||||
// benchmark. for CPU, op time means calc duration; for others, op time means schedule duration.
|
// benchmark. for CPU, op time means calc duration; for others, op time means schedule duration.
|
||||||
{
|
{
|
||||||
|
@ -419,6 +419,16 @@ static int test_main(int argc, const char* argv[]) {
|
||||||
};
|
};
|
||||||
|
|
||||||
if (t > 0) {
|
if (t > 0) {
|
||||||
|
#define WARMUP
|
||||||
|
#ifdef WARMUP
|
||||||
|
// warmup: 10
|
||||||
|
for (int warmup = 0; warmup < 10; ++warmup) {
|
||||||
|
inputTensor->copyFromHostTensor(&givenTensor);
|
||||||
|
net->runSession(session);
|
||||||
|
outputTensor->copyToHostTensor(&expectTensor);
|
||||||
|
}
|
||||||
|
#endif // WARMUP
|
||||||
|
|
||||||
std::vector<float> times(t, 0.0f);
|
std::vector<float> times(t, 0.0f);
|
||||||
for (int i = 0; i < t; ++i) {
|
for (int i = 0; i < t; ++i) {
|
||||||
auto begin = getTimeInUs();
|
auto begin = getTimeInUs();
|
||||||
|
|
Loading…
Reference in New Issue