2021-09-18 15:52:30 +08:00
|
|
|
//
|
|
|
|
// AVX2Functions.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2021/05/17.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
#include "AVX2Functions.hpp"
|
|
|
|
#include "AVX2Backend.hpp"
|
|
|
|
#include "avx/FunctionSummary.hpp"
|
|
|
|
#include "avxfma/FunctionSummary.hpp"
|
|
|
|
#include "avx512/FunctionSummary.hpp"
|
|
|
|
namespace MNN {
|
|
|
|
// Matmul pack sizes for the active instruction set. The three values are
// handed back to callers through the e/l/h out-parameters of
// _MNNGetMatMulPackMode.
// NOTE(review): presumably e/l/h name the three matmul dimensions used by the
// packing routines -- confirm against the CoreFunctions documentation.
struct MatMulPackParam {
    int eP;  // value reported through the eP out-parameter
    int lP;  // value reported through the lP out-parameter
    int hP;  // value reported through the hP out-parameter
};
|
|
|
|
|
|
|
|
static MatMulPackParam gPackInfo;
|
|
|
|
static CoreFunctions* gAVX2CoreFunctions = nullptr;
|
2021-09-18 15:52:30 +08:00
|
|
|
static CoreInt8Functions* gAVX2CoreInt8Functions = nullptr;
|
2021-06-11 17:17:13 +08:00
|
|
|
static void _MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
|
|
|
|
*eP = gPackInfo.eP;
|
|
|
|
*lP = gPackInfo.lP;
|
|
|
|
*hP = gPackInfo.hP;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool AVX2Functions::init(int cpuFlags) {
|
|
|
|
gAVX2CoreFunctions = new CoreFunctions;
|
|
|
|
auto coreFunction = gAVX2CoreFunctions;
|
2021-09-18 15:52:30 +08:00
|
|
|
gAVX2CoreInt8Functions = new CoreInt8Functions;
|
2021-06-11 17:17:13 +08:00
|
|
|
// Init default functions
|
|
|
|
*coreFunction = *MNNGetCoreFunctions();
|
2021-09-18 15:52:30 +08:00
|
|
|
*gAVX2CoreInt8Functions = *MNNGetInt8CoreFunctions();
|
|
|
|
_AVX_MNNInt8FunctionInit(gAVX2CoreInt8Functions);
|
2021-06-11 17:17:13 +08:00
|
|
|
// Init AVX2
|
|
|
|
coreFunction->MNNGetMatMulPackMode = _MNNGetMatMulPackMode;
|
|
|
|
gPackInfo.eP = 24;
|
|
|
|
gPackInfo.lP = 1;
|
|
|
|
gPackInfo.hP = 4;
|
2021-09-18 15:52:30 +08:00
|
|
|
_AVX_ReorderInit(coreFunction);
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul;
|
|
|
|
coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain;
|
|
|
|
coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A;
|
|
|
|
coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B;
|
|
|
|
coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1;
|
|
|
|
coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1;
|
2021-09-18 15:52:30 +08:00
|
|
|
|
|
|
|
// For Packed Functions
|
|
|
|
coreFunction->pack = 8;
|
2021-06-11 17:17:13 +08:00
|
|
|
_AVX_ExtraInit(coreFunction);
|
|
|
|
// Winograd
|
|
|
|
_AVX_WinogradInit(coreFunction);
|
|
|
|
if (cpuFlags & libyuv::kCpuHasFMA3) {
|
|
|
|
coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA;
|
|
|
|
coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA;
|
|
|
|
coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1FMA;
|
|
|
|
coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1FMA;
|
2021-09-18 15:52:30 +08:00
|
|
|
_AVX_ExtraInitFMA(coreFunction);
|
2021-06-11 17:17:13 +08:00
|
|
|
}
|
|
|
|
#ifdef MNN_AVX512
|
|
|
|
if ((cpuFlags & libyuv::kCpuHasAVX512VNNI)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512VL)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512BW)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512VBMI)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512VBITALG)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512VPOPCNTDQ)
|
|
|
|
|| (cpuFlags & libyuv::kCpuHasAVX512VBMI2)
|
|
|
|
) {
|
2021-09-18 15:52:30 +08:00
|
|
|
coreFunction->pack = 16;
|
|
|
|
_AVX512_ReorderInit(coreFunction);
|
|
|
|
_AVX512_ExtraInit(coreFunction);
|
|
|
|
_AVX512_WinogradInit(coreFunction);
|
2021-06-11 17:17:13 +08:00
|
|
|
coreFunction->MNNPackForMatMul_B = _AVX512_MNNPackForMatMul_B;
|
|
|
|
coreFunction->MNNPackC4ForMatMul_A = _AVX512_MNNPackC8ForMatMul_A;
|
|
|
|
coreFunction->MNNPackedMatMul = _AVX512_MNNPackedMatMul;
|
|
|
|
coreFunction->MNNPackedMatMulRemain = _AVX512_MNNPackedMatMulRemain;
|
|
|
|
gPackInfo.eP = 48;
|
|
|
|
gPackInfo.hP = 8;
|
|
|
|
gPackInfo.lP = 1;
|
|
|
|
}
|
2021-09-18 15:52:30 +08:00
|
|
|
#ifdef MNN_AVX512_VNNI
|
|
|
|
if (cpuFlags & libyuv::kCpuHasAVX512VNNI) {
|
|
|
|
_AVX512_MNNInt8FunctionInit(gAVX2CoreInt8Functions);
|
|
|
|
}
|
|
|
|
#endif
|
2021-06-11 17:17:13 +08:00
|
|
|
#endif
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
// Returns the float function table allocated by AVX2Functions::init();
// the pointer is null if init() has not run yet.
CoreFunctions* AVX2Functions::get() {
    CoreFunctions* const table = gAVX2CoreFunctions;
    return table;
}
|
2021-09-18 15:52:30 +08:00
|
|
|
// Returns the int8 function table allocated by AVX2Functions::init();
// the pointer is null if init() has not run yet.
CoreInt8Functions* AVX2Functions::getInt8() {
    CoreInt8Functions* const table = gAVX2CoreInt8Functions;
    return table;
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
};
|