//
//  FunctionDispatcher.cpp
//  MNN
//
//  Created by MNN on 2019/08/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include <limits>
#include "avx512/FunctionSummary.hpp"
#include "avx/FunctionSummary.hpp"
#include "avxfma/FunctionSummary.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "backend/cpu/compute/ConvOpt.h"
#include "backend/cpu/compute/Int8FunctionsOpt.h"
#include "cpu_id.h"
#include "sse/FunctionSummary.hpp"
// https://stackoverflow.com/a/11230437
#if defined(_MSC_VER)
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
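
// The 4x4 reorder always forwards to the SSE implementation; no AVX variant
// is registered here, so it stays outside the runtime dispatch table.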
bool MNNReorder4x4ByPlatform(float* dst, size_t number) {
    return _SSE_MNNReorder4x4ByPlatform(dst, number);
}
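
// FunctionGroup is the runtime dispatch table: every entry starts at the SSE
// baseline, and MNNFunctionInit() rebinds individual entries when AVX2 / FMA3 /
// AVX-512 support is detected on the running CPU.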
struct FunctionGroup {
    int tileNumber = 8;
    int eP = 12;
    int lP = 1;
    int hP = 4;
    void (*MNNGemmInt8AddBiasScale_16x4_Unit)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
    void (*MNNGemmInt8AddBiasScale_16x4_Unit_FAST)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
    void (*MNNExpC8)(float* dest, const float* source, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
    void (*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                          ssize_t maxValue, ssize_t zeroPoint) = _SSE_MNNFloat2Int8;
    void (*MNNInt8ScaleToFloat)(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) = _SSE_MNNInt8ScaleToFloat;
    void (*MNNLineDepthWiseInt8AddBiasScaleUnit)(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) = _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit;
    void (*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) = _SSE_MNNComputeMatMulForE_1;
    void (*MNNReluWithSlopeChannel)(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) = _SSE_MNNReluWithSlopeChannel;
    void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size) = _SSE_MNNReluInt8;
    void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish;
};
static FunctionGroup gFunc;
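
// Report the packing parameters (eP, lP, hP) of the GEMM kernels currently
// installed in gFunc, so callers can pack the matmul inputs to match.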
void _SSEMNNGetMatMulPackMode(int* eP, int* lP, int* hP) {
    *eP = gFunc.eP;
    *lP = gFunc.lP;
    *hP = gFunc.hP;
}
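
// One-time initialization: query CPU features through libyuv and upgrade the
// SSE defaults (both in gFunc and in MNN's CoreFunctions table) to the widest
// instruction set the running CPU supports.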
void MNNFunctionInit() {
    auto cpuFlags = libyuv::InitCpuFlags();
    auto coreFunction = MNN::MNNGetCoreFunctions();
    if (cpuFlags & libyuv::kCpuHasSSSE3) {
        coreFunction->MNNGetMatMulPackMode = _SSEMNNGetMatMulPackMode;
        coreFunction->MNNMatrixAdd = _SSE_MNNMatrixAdd;
        coreFunction->MNNMatrixSub = _SSE_MNNMatrixSub;
        coreFunction->MNNPackedMatMul = _SSE_MNNPackedMatMul;
        coreFunction->MNNPackedMatMulRemain = _SSE_MNNPackedMatMulRemain;
        coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A;
        coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B;
        coreFunction->MNNConvRunForLineDepthwise = _SSE_MNNConvRunForLineDepthwise;
        coreFunction->MNNAxByClampBroadcastUnit = _SSE_MNNAxByClampBroadcastUnit;
    }
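    // AVX2: widen the packed matmul to a 24-wide eP tile and upgrade the int8
    // GEMM and elementwise kernels held in gFunc.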
    if (cpuFlags & libyuv::kCpuHasAVX2) {
        gFunc.eP = 24;
        gFunc.lP = 1;
        gFunc.hP = 4;
        coreFunction->MNNMatrixAdd = _AVX_MNNMatrixAdd;
        coreFunction->MNNMatrixSub = _AVX_MNNMatrixSub;
        coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul;
        coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain;
        coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A;
        coreFunction->MNNConvRunForLineDepthwise = _AVX_MNNConvRunForLineDepthwise;
        coreFunction->MNNAxByClampBroadcastUnit = _AVX_MNNAxByClampBroadcastUnit;
        gFunc.MNNGemmInt8AddBiasScale_16x4_Unit = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit;
        gFunc.MNNExpC8 = _AVX_MNNExpC8;
        gFunc.MNNFloat2Int8 = _AVX_MNNFloat2Int8;
        gFunc.MNNInt8ScaleToFloat = _AVX_MNNInt8ScaleToFloat;
        gFunc.MNNLineDepthWiseInt8AddBiasScaleUnit = _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit;
        gFunc.MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1;
        gFunc.MNNGemmInt8AddBiasScale_16x4_Unit_FAST = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast;
        gFunc.MNNReluWithSlopeChannel = _AVX_MNNReluWithSlopeChannel;
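        // FMA3: use fused multiply-add variants of the packed matmul kernels.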
        if (cpuFlags & libyuv::kCpuHasFMA3) {
            coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA;
            coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA;
            gFunc.MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1FMA;
        }
    }
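    // AVX-512 kernels are compiled only when MNN_AVX512 is defined, and are
    // selected only on CPUs that also report AVX512-VNNI.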
#ifdef MNN_AVX512
    if (cpuFlags & libyuv::kCpuHasAVX512VNNI) {
        coreFunction->MNNPackForMatMul_B = _AVX512_MNNPackForMatMul_B;
        coreFunction->MNNPackC4ForMatMul_A = _AVX512_MNNPackC4ForMatMul_A;
        coreFunction->MNNPackedMatMul = _AVX512_MNNPackedMatMul;
        coreFunction->MNNPackedMatMulRemain = _AVX512_MNNPackedMatMulRemain;
        gFunc.eP = 24;
        gFunc.hP = 4;
        gFunc.lP = 4;
        gFunc.MNNGemmInt8AddBiasScale_16x4_Unit = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit;
        gFunc.MNNGemmInt8AddBiasScale_16x4_Unit_FAST = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit;
    }
#endif
}
// ========= CommonOptFunction.cpp ===========
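// The wrappers below are the symbols the rest of the CPU backend calls; each
// forwards either to the runtime-dispatched pointer in gFunc or directly to
// the SSE baseline.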
void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    _SSE_MNNCopyC4WithStride(source, dest, srcStride, dstStride, count);
}
void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    _SSE_MNNAddC4WithStride(source, dest, srcStride, dstStride, count);
}
void MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    return gFunc.MNNReluWithSlopeChannel(dst, src, slope, sizeQuad, depthQuad);
}
void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size) {
    return gFunc.MNNReluInt8(dst, src, size);
}
void MNNHardSwish(float* dst, const float* src, size_t size) {
    return gFunc.MNNHardSwish(dst, src, size);
}
void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue,
                   ssize_t maxValue, ssize_t zeroPoint) {
    return gFunc.MNNFloat2Int8(src, dst, sizeQuad, scalep, minValue, maxValue, zeroPoint);
}
void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint) {
    return gFunc.MNNInt8ScaleToFloat(dst, src, scale, size, zeroPoint);
}
void MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8) {
    gFunc.MNNExpC8(dest, source, parameters, countC8);
}
void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                       size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    return gFunc.MNNGemmInt8AddBiasScale_16x4_Unit(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realDst);
}
void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
    gFunc.MNNLineDepthWiseInt8AddBiasScaleUnit(dst, src, weight, parameters, width, src_w_step, fw, fh, dilateX_step, dilateY_step);
}
void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {
    _SSE_MNNInt8ToInt16(dest, source, count);
}
void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) {
    gFunc.MNNComputeMatMulForE_1(A, B, C, biasPtr, param, tId);
}
void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
    gFunc.MNNGemmInt8AddBiasScale_16x4_Unit_FAST(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realCount);
}