MNN/source/backend/cpu/x86_x64/avx/GemmInt8.cpp

761 lines
34 KiB
C++

//
// GemmInt8.cpp
// MNN
//
// Created by MNN on 2020/09/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "GemmCommon.hpp"
#include "FunctionSummary.hpp"
#include "core/Macro.h"
#include <math.h>
#define AVX2_PACKINT8 8
#define GEMMINT8_AVX2_E 4
#define GEMMINT8_AVX2_L 4
#define GEMMINT8_AVX2_H 8
namespace {
// Unaligned 128-bit load from an arbitrary byte address.
static inline __m128i mm_loadu_si128(const void* addr) {
    const __m128i* source = static_cast<const __m128i*>(addr);
    return _mm_loadu_si128(source);
}
// Writes the low 8 bytes of `value` to `add`; `add` needs no particular alignment and
// the upper 8 bytes of `value` are discarded.
static inline void MNN__mm_storeu_si64(void* add, __m128i value) {
    int64_t lanes[2];
    _mm_storeu_si128(reinterpret_cast<__m128i*>(lanes), value);
    ::memcpy(add, &lanes[0], sizeof(int64_t));
}
} // namespace
// POSTTREAT(N): int8 output step for column N, expanded inside the GEMM kernels.
// Expects in the caller's scope: __m256 f##N (scaled result), __m256i D##N, and the locals
// minValue, maxValue, zero128, plus, minus, offset (__m256 / __m256i) and dst_x (int8_t*).
//   1. clamp f##N to [minValue, maxValue];
//   2. add +0.5 (or -0.5 where f##N < 0, cmp predicate 1 = _CMP_LT_OS) and truncate
//      (_mm256_round_ps mode 3 = _MM_FROUND_TO_ZERO), i.e. round half away from zero;
//   3. add `offset`, narrow the 8 int32 lanes to int16 (signed saturation) and then to 8 bytes
//      (unsigned saturation);
//   4. store those 8 bytes at dst_x + N * 8.
// Side effects on the caller: declares m##N and d##N, overwrites f##N and D##N.
// Do not put // comments inside the macro body: line splicing would swallow the following lines.
#define POSTTREAT(N) \
f##N = _mm256_min_ps(f##N, maxValue);\
f##N = _mm256_max_ps(f##N, minValue);\
auto m##N = _mm256_cmp_ps(f##N, zero128, 1);\
m##N = _mm256_blendv_ps(plus, minus, m##N);\
f##N = _mm256_add_ps(f##N, m##N);\
D##N = _mm256_cvtps_epi32(_mm256_round_ps(f##N, 3));\
D##N = _mm256_add_epi32(D##N, offset);\
D##N = _mm256_packs_epi32(D##N, _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(D##N), _mm256_castsi256_ps(D##N), 1)));\
auto d##N = _mm_packus_epi16(_mm256_castsi256_si128(D##N), _mm256_castsi256_si128(_mm256_castps_si256(zero128)));\
MNN__mm_storeu_si64(dst_x + N * 8, d##N);
inline __m256i NORMAL_HADD(__m256i x, __m256i y) {
auto c0 = _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(x), _mm256_castsi256_ps(y), 32));
auto c1 = _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(x), _mm256_castsi256_ps(y), 49));
return _mm256_hadd_epi32(c0, c1);
}
// EXTRACT_ADD(i): declares d##i##0 / d##i##1 as the low / high 128-bit halves of the __m256i D##i
// and d##i as their lane-wise int32 sum.
// NOTE(review): appears unused by the kernels that follow — confirm before removing.
#define EXTRACT_ADD(i)\
auto d##i##0 = _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(D##i), 0));\
auto d##i##1 = _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(D##i), 1));\
auto d##i = _mm_add_epi32(d##i##0, d##i##1);
// COMPUTE(u, v): D##u##v += _mm256_madd_epi16(W##u, S##v), i.e. multiply 16 int16 pairs and add
// each adjacent pair of products into one of 8 int32 lanes. Expects __m256i W##u, S##v and D##u##v
// in the caller's scope.
#define COMPUTE(u, v)\
D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v));
// Int8 GEMM tile kernel for AVX2 (exact path: bytes are widened to int16 before multiplying).
// For every output block dz in [0, dst_depth_quad) it produces `realDst` columns (1..GEMMINT8_AVX2_E)
// of 8 output lanes each:
//   out[v][k] = sum over sz < src_depth_quad and l < 4 of
//               (uint8)src[sz * 16 + 4 * v + l] * (int8)weight_dz[sz * 32 + 4 * k + l]
// with weight_dz = weight + dz * src_depth_quad * 32. The src tile is always indexed with the
// GEMMINT8_AVX2_E (4) column stride, even when realDst < 4.
// The 8 int32 sums then get the 8 values at post->bias + dz * 8 added (loaded as int32), are converted
// to float and multiplied by post->scale[dz * 8 + k].
// Output goes to dst + dz * dst_step (dst_step in bytes):
//   post->useInt8 == 0 -> 8 floats per column (column v at float offset 8 * v);
//   otherwise          -> POSTTREAT clamps to [minValue, maxValue], rounds half away from zero,
//                         adds 128 and stores 8 unsigned-saturated bytes per column (byte offset 8 * v).
// Any realDst outside 1..4 writes nothing.
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
const auto dst_step_tmp = dst_step / sizeof(int8_t);
// Constants consumed by the POSTTREAT macro.
auto zero128 = _mm256_set1_ps(0.0f);
auto minValue = _mm256_set1_ps(post->minValue);
auto maxValue = _mm256_set1_ps(post->maxValue);
auto plus = _mm256_set1_ps(0.5f);
auto minus = _mm256_set1_ps(-0.5f);
auto offset = _mm256_set1_epi32(128);
//printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad);
// Full tile: 4 columns.
if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
// D0v: weight bytes 0-15 against column v (output lanes 0-3 after NORMAL_HADD);
// D1v: weight bytes 16-31 against column v (output lanes 4-7).
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D03 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
__m256i D12 = _mm256_set1_epi32(0);
__m256i D13 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
// Weights are sign-extended to int16.
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
// The 4 src bytes of each column are broadcast and zero-extended to int16 (read as unsigned).
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
auto s3 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 3));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
auto S2 = _mm256_cvtepu8_epi16(s2);
auto S3 = _mm256_cvtepu8_epi16(s3);
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
COMPUTE(0, 2);
COMPUTE(1, 2);
COMPUTE(0, 3);
COMPUTE(1, 3);
}
// Fold the two partial sums per output lane: Dv lane k = full 4-byte dot product for lane k.
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto D2 = NORMAL_HADD(D02, D12);
auto D3 = NORMAL_HADD(D03, D13);
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
D2 = _mm256_add_epi32(D2, biasValue0);
D3 = _mm256_add_epi32(D3, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
auto f3 = _mm256_cvtepi32_ps(D3);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
f3 = _mm256_mul_ps(f3, scaleValue);
// Float output (8 floats per column) or int8 output via POSTTREAT (8 bytes per column).
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
POSTTREAT(3);
}
}
return;
}
// Tail tile: 3 columns, same computation as the full tile.
if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
__m256i D12 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
auto S2 = _mm256_cvtepu8_epi16(s2);
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
COMPUTE(0, 2);
COMPUTE(1, 2);
}
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto D2 = NORMAL_HADD(D02, D12);
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
D2 = _mm256_add_epi32(D2, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
}
}
return;
}
// Tail tile: 2 columns.
if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
}
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} else {
POSTTREAT(0);
POSTTREAT(1);
}
}
return;
}
// Tail tile: 1 column.
if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto S0 = _mm256_cvtepu8_epi16(s0);
COMPUTE(0, 0);
COMPUTE(1, 0);
}
auto D0 = NORMAL_HADD(D00, D10);
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
f0 = _mm256_mul_ps(f0, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} else {
POSTTREAT(0);
}
}
return;
}
// Any other realDst: nothing is written.
}
// Int8 GEMM tile kernel for AVX2 ("Fast" path). For every output block dz in [0, dst_depth_quad) it
// produces `realDst` columns (1..GEMMINT8_AVX2_E) of 8 output lanes each:
//   out[v][k] = sum over sz < src_depth_quad and l < 4 of
//               (uint8)src[sz * 16 + 4 * v + l] * (int8)weight_dz[sz * 32 + 4 * k + l]
// with weight_dz = weight + dz * src_depth_quad * 32. The src tile is always indexed with the
// GEMMINT8_AVX2_E (4) column stride, even when realDst < 4.
// The sums get the 8 values at post->bias + dz * 8 added (loaded as int32), are converted to float and
// multiplied by post->scale[dz * 8 + k]; they are stored as 8 floats per column when
// post->useInt8 == 0, otherwise POSTTREAT writes 8 bytes per column (byte offset 8 * v).
// The multiply uses _mm256_maddubs_epi16 (unsigned src byte x signed weight byte), which adds each pair
// of products with signed int16 SATURATION before _mm256_madd_epi16(..., 1) widens to int32.
// A pair like 255 * 127 + 255 * 127 = 64770 therefore clips to 32767.
// NOTE(review): presumably callers only pick this kernel when the value ranges keep such pair sums
// within int16 — TODO confirm against the dispatch code.
// Any realDst outside 1..4 writes nothing.
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
const auto dst_step_tmp = dst_step / sizeof(int8_t);
// Constants consumed by the POSTTREAT macro.
auto zero128 = _mm256_set1_ps(0.0f);
auto minValue = _mm256_set1_ps(post->minValue);
auto maxValue = _mm256_set1_ps(post->maxValue);
auto plus = _mm256_set1_ps(0.5f);
auto minus = _mm256_set1_ps(-0.5f);
// madd with 1 adds adjacent int16 lanes into int32.
auto oneValue = _mm256_set1_epi16(1);
auto offset = _mm256_set1_epi32(128);
//printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad);
// Full tile: 4 columns.
if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
// D0v: the 8 int32 output lanes of column v.
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D03 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
// All 32 weight bytes of this step in one register.
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
// The 4 src bytes of each column, broadcast to all 8 dwords.
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto D2 = D02;
auto D3 = D03;
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
D2 = _mm256_add_epi32(D2, biasValue0);
D3 = _mm256_add_epi32(D3, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
auto f3 = _mm256_cvtepi32_ps(D3);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
f3 = _mm256_mul_ps(f3, scaleValue);
// Float output (8 floats per column) or int8 output via POSTTREAT (8 bytes per column).
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
POSTTREAT(3);
}
}
return;
}
// Tail tile: 3 columns, same computation as the full tile.
if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto D2 = D02;
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
D2 = _mm256_add_epi32(D2, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
}
}
return;
}
// Tail tile: 2 columns.
if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
D1 = _mm256_add_epi32(D1, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} else {
POSTTREAT(0);
POSTTREAT(1);
}
}
return;
}
// Tail tile: 1 column.
if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto bias_dz = post->bias + dz * AVX2_PACKINT8;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
}
auto D0 = D00;
auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
D0 = _mm256_add_epi32(D0, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
f0 = _mm256_mul_ps(f0, scaleValue);
if (post->useInt8 == 0) {
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} else {
POSTTREAT(0);
}
}
return;
}
// Any other realDst: nothing is written.
}
#undef MAIN_COMPUTE
#undef STORE_TEMP
// Depthwise int8 convolution over one output line (AVX2), 16 channels per output position.
//   dstO    : receives `width` groups of 16 bytes (advances 16 bytes per position).
//   srcO    : read as int16 values, 16 per load; advances src_w_step int16 elements per position.
//             Kernel tap (fx, fy) reads at fx * dilateX_step + fy * dilateY_step int16 elements.
//   weightO : read as int16 values, fw * fh groups of 16 (tap (fx, fy) at (fy * fw + fx) * 16).
//   parameters->bias : 16 values added as int32; parameters->scale : 16 floats.
// Each result is scaled, rounded half away from zero, offset by +128, clamped to
// [minValue + 128, maxValue + 128] and stored as an unsigned-saturated byte.
// idxOrder is not used by this implementation.
void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder) {
    const int pack = 16;
    auto dst = dstO;
    auto src = (const int16_t*)srcO;
    auto weight = (const int16_t*)weightO;
    auto biasValue0 = _mm256_castps_si256(_mm256_loadu_ps((const float*)parameters->bias));
    auto biasValue1 = _mm256_castps_si256(_mm256_loadu_ps((const float*)parameters->bias + 8));
    auto scaleValue0 = _mm256_loadu_ps((const float*)parameters->scale);
    auto scaleValue1 = _mm256_loadu_ps((const float*)parameters->scale + 8);
    // Explicit zero: xor-ing an uninitialized register with itself reads an indeterminate value (UB).
    const __m256i zero = _mm256_setzero_si256();
    const __m256 zero256 = _mm256_set1_ps(0.0f);
    const auto minValue = _mm256_set1_epi16((int16_t)(parameters->minValue + 128));
    const auto maxValue = _mm256_set1_epi16((int16_t)(parameters->maxValue + 128));
    const __m256 plus = _mm256_set1_ps(0.5f);
    const __m256 minus = _mm256_set1_ps(-0.5f);
    const auto offset = _mm256_set1_epi32(128);
    for (size_t dx = 0; dx < width; ++dx) {
        // d0 accumulates channels 0-7, d1 channels 8-15, both starting from the bias.
        __m256i d0 = biasValue0;
        __m256i d1 = biasValue1;
        const auto src_z = src;
        for (size_t fy = 0; fy < fh; ++fy) {
            const auto src_y = src_z + fy * dilateY_step;
            const auto weight_y = weight + fy * fw * pack;
            for (size_t fx = 0; fx < fw; ++fx) {
                const auto src_x = src_y + fx * dilateX_step;
                auto s0_16 = _mm256_castps_si256(_mm256_loadu_ps((const float*)src_x));
                s0_16 = _mm256_permute4x64_epi64(s0_16, 0xD8); // Reorder 0,1,2,3->0,2,1,3 to ensure s0_32 is 0,1 and s1_32 is 2,3.
                // Interleaving with zero turns each int16 into an (x, 0) pair, so madd_epi16 below
                // yields the plain 32-bit product w * s for each channel.
                auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
                auto s1_32 = _mm256_unpackhi_epi16(s0_16, zero);
                const auto weight_x = weight_y + pack * fx;
                auto w0_16 = _mm256_castps_si256(_mm256_loadu_ps((const float*)weight_x));
                w0_16 = _mm256_permute4x64_epi64(w0_16, 0xD8);
                auto w0_32 = _mm256_unpacklo_epi16(w0_16, zero);
                auto w1_32 = _mm256_unpackhi_epi16(w0_16, zero);
                d0 = _mm256_add_epi32(d0, _mm256_madd_epi16(w0_32, s0_32));
                d1 = _mm256_add_epi32(d1, _mm256_madd_epi16(w1_32, s1_32));
            }
        }
        __m256 f0 = _mm256_cvtepi32_ps(d0);
        __m256 f1 = _mm256_cvtepi32_ps(d1);
        f0 = _mm256_mul_ps(f0, scaleValue0);
        f1 = _mm256_mul_ps(f1, scaleValue1);
        // Round half away from zero: add +/-0.5 by sign, then truncate.
        auto m0 = _mm256_cmp_ps(f0, zero256, 1);
        auto m1 = _mm256_cmp_ps(f1, zero256, 1);
        m0 = _mm256_blendv_ps(plus, minus, m0);
        m1 = _mm256_blendv_ps(plus, minus, m1);
        f0 = _mm256_add_ps(f0, m0);
        f1 = _mm256_add_ps(f1, m1);
        // _MM_FROUND_TO_ZERO
        d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3));
        d1 = _mm256_cvtps_epi32(_mm256_round_ps(f1, 3));
        d0 = _mm256_add_epi32(d0, offset);
        d1 = _mm256_add_epi32(d1, offset);
        // packs works per 128-bit lane; the 0xD8 permute restores channel order 0..15 as int16.
        d0 = _mm256_permute4x64_epi64(_mm256_packs_epi32(d0, d1), 0xD8);
        d0 = _mm256_min_epi16(d0, maxValue);
        d0 = _mm256_max_epi16(d0, minValue);
        // After packus + permute, the low 128 bits hold the 16 output bytes in channel order.
        auto y256i = _mm256_permute4x64_epi64(_mm256_packus_epi16(d0, _mm256_setzero_si256()), 0xD8);
        auto y128 = _mm_castsi128_ps(_mm256_extracti128_si256(y256i, 0));
        _mm_storeu_ps((float*)dst, y128);
        dst += 16;
        src += src_w_step;
    }
}
void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint) {
auto zero = _mm256_set1_epi32(0);
auto minValue = _mm256_set1_ps(minV);
auto maxValue = _mm256_set1_ps(maxV);
auto zeroPointValue = _mm256_set1_ps(zeroPoint);
auto offset = _mm256_set1_epi32(128);
auto plus = _mm256_set1_ps(0.5f);
auto minus = _mm256_set1_ps(-0.5f);
auto scaleValue = _mm256_loadu_ps(scalep);
for (int i = 0; i < sizeQuad; ++i) {
auto f0 = _mm256_loadu_ps(src + 8 * i);
f0 = _mm256_mul_ps(f0, scaleValue);
f0 = _mm256_add_ps(f0, zeroPointValue);
f0 = _mm256_min_ps(f0, maxValue);
f0 = _mm256_max_ps(f0, minValue);
auto m0 = _mm256_cmp_ps(f0, _mm256_castsi256_ps(zero), 1);
m0 = _mm256_blendv_ps(plus, minus, m0);
f0 = _mm256_add_ps(f0, m0);
// 3: _MM_FROUND_TO_ZERO
auto d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3));
d0 = _mm256_add_epi32(d0, offset);
d0 = _mm256_packs_epi32(d0, _mm256_setzero_si256());
d0 = _mm256_permute4x64_epi64(d0, 0xD8);
#if defined(_MSC_VER)
__m256i x = static_cast<__m256i>(_mm256_packus_epi16(d0, _mm256_setzero_si256()));
*((int64_t*)dst + i) = x.m256i_i64[0];
#else
__v4di x = static_cast<__v4di>(_mm256_packus_epi16(d0, _mm256_setzero_si256()));
*((int64_t*)dst + i) = x[0];
#endif
}
}
void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, ssize_t zeroPoint) {
auto sizeC4 = sizeQuad / 4;
auto sizeRemain = sizeQuad % 4;
auto zero = _mm256_set1_epi32(0);
auto scaleValue = _mm256_loadu_ps(scale);
auto zeroPointValue = _mm256_set1_epi32(zeroPoint + 128);
for (int i = 0; i < sizeC4; ++i) {
auto s = _mm256_castps_si256(_mm256_loadu_ps((const float*)(src)));
auto s0_16 = _mm256_permute4x64_epi64(_mm256_unpacklo_epi8(s, zero), 0XD8);
auto s1_16 = _mm256_permute4x64_epi64(_mm256_unpackhi_epi8(s, zero), 0xD8);
auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
auto s1_32 = _mm256_unpacklo_epi16(s1_16, zero);
auto s2_32 = _mm256_unpackhi_epi16(s0_16, zero);
auto s3_32 = _mm256_unpackhi_epi16(s1_16, zero);
s0_32 = _mm256_sub_epi32(s0_32, zeroPointValue);
s1_32 = _mm256_sub_epi32(s1_32, zeroPointValue);
s2_32 = _mm256_sub_epi32(s2_32, zeroPointValue);
s3_32 = _mm256_sub_epi32(s3_32, zeroPointValue);
auto s0_f = _mm256_cvtepi32_ps(s0_32);
auto s1_f = _mm256_cvtepi32_ps(s1_32);
auto s2_f = _mm256_cvtepi32_ps(s2_32);
auto s3_f = _mm256_cvtepi32_ps(s3_32);
_mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 2, _mm256_mul_ps(s2_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 3, _mm256_mul_ps(s3_f, scaleValue));
src += 32;
dst += 32;
}
if (sizeRemain > 0) {
int8_t srcTemp[256];
::memcpy(srcTemp, src, sizeRemain * 8);
auto s = _mm256_castps_si256(_mm256_loadu_ps((const float*)(srcTemp)));
auto s0_16 = _mm256_permute4x64_epi64(_mm256_unpacklo_epi8(s, zero), 0XD8);
auto s1_16 = _mm256_permute4x64_epi64(_mm256_unpackhi_epi8(s, zero), 0xD8);
auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
auto s1_32 = _mm256_unpacklo_epi16(s1_16, zero);
auto s2_32 = _mm256_unpackhi_epi16(s0_16, zero);
auto s3_32 = _mm256_unpackhi_epi16(s1_16, zero);
s0_32 = _mm256_sub_epi32(s0_32, zeroPointValue);
s1_32 = _mm256_sub_epi32(s1_32, zeroPointValue);
s2_32 = _mm256_sub_epi32(s2_32, zeroPointValue);
s3_32 = _mm256_sub_epi32(s3_32, zeroPointValue);
auto s0_f = _mm256_cvtepi32_ps(s0_32);
auto s1_f = _mm256_cvtepi32_ps(s1_32);
auto s2_f = _mm256_cvtepi32_ps(s2_32);
auto s3_f = _mm256_cvtepi32_ps(s3_32);
switch (sizeRemain) {
case 3:
_mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 2, _mm256_mul_ps(s2_f, scaleValue));
break;
case 2:
_mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
_mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
break;
case 1:
_mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
break;
default:
break;
}
}
}
// Reports the tile shape used by the AVX2 int8 GEMM kernels in this file:
//   UNIT      <- GEMMINT8_AVX2_H (8 output lanes per tile),
//   SRC_UNIT  <- GEMMINT8_AVX2_L (4 bytes of reduction depth per step),
//   DST_XUNIT <- GEMMINT8_AVX2_E (4 columns per tile).
static void _AVX2_MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    // Written in this order so aliased out-pointers end with the same value as before.
    const int outputLanes = GEMMINT8_AVX2_H;
    const int depthPerStep = GEMMINT8_AVX2_L;
    const int columnsPerTile = GEMMINT8_AVX2_E;
    *UNIT = outputLanes;
    *SRC_UNIT = depthPerStep;
    *DST_XUNIT = columnsPerTile;
}
// Packs int8 activations into the source layout consumed by the AVX2 int8 GEMM kernels, moving
// 4 bytes at a time as one int32.
//   info[0]: number of source groups; info[1]: eReal; info[2]: byte distance between consecutive
//            GEMMINT8_AVX2_E-column blocks in dest; info[3]: xStride.
//   el[4n + 0..3] for group n: e (column count), l (depth in bytes, processed in units of 4),
//            eOffset (first destination column), lOffset (first destination depth byte).
// Inside a destination block, depth unit x of column c is stored at int32 index x * E + c.
static void _AVXMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
    const int number = info[0];
    const int eReal = info[1];
    const int xStride = info[3];
    // Strides below are in int32 units.
    const int xS4 = xStride * AVX2_PACKINT8 / sizeof(int32_t);
    const int PUNIT = AVX2_PACKINT8 / GEMMINT8_AVX2_L;
    const int FLOATPACK = AVX2_PACKINT8 / sizeof(int32_t);
    const int eOutsideStride = info[2] / sizeof(int32_t);
    const int EP = GEMMINT8_AVX2_E;
    const int LP = GEMMINT8_AVX2_L;
    const int eDest = EP;
    for (int n = 0; n < number; ++n) {
        const int32_t* group = el + 4 * n;
        const int e = group[0];
        const int depthUnits = group[1] / 4;
        const int eOffset = group[2];
        const int lOffset = group[3];
        const int eC = eOffset / eDest;
        const int eR = eOffset % eDest;
        const int32_t* source = reinterpret_cast<const int32_t*>(sourceGroup[n]);
        int32_t* dest = reinterpret_cast<int32_t*>(destOrigin + eC * info[2] + eR * LP + lOffset * EP);
        for (int x = 0; x < depthUnits; ++x) {
            int32_t* d = dest + x * eDest;
            const int32_t* s = source + (x / PUNIT) * eReal * FLOATPACK + (x % PUNIT);
            // The first destination block may already hold eR columns; later blocks start empty.
            int capacity = eDest - eR;
            int alreadyFilled = eR;
            int remaining = e;
            while (remaining > 0) {
                const int count = remaining < capacity ? remaining : capacity;
                for (int yi = 0; yi < count; ++yi) {
                    d[yi] = s[yi * xS4];
                }
                remaining -= count;
                d += eOutsideStride - alreadyFilled;
                s += count * xS4;
                capacity = eDest;
                alreadyFilled = 0;
            }
        }
    }
}
void _AVX_MNNInt8FunctionInit(void* functions) {
auto gAVX2CoreInt8Functions = (MNN::CoreInt8Functions*)functions;
// MatMul
gAVX2CoreInt8Functions->Int8GemmKernel = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit;
gAVX2CoreInt8Functions->Int8GemmKernelFast = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast;
gAVX2CoreInt8Functions->MNNGetGemmUnit = _AVX2_MNNGetGemmUnit;
gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _AVXMNNPackC4ForMatMul_A;
// Int8 <-> Float
gAVX2CoreInt8Functions->MNNFloat2Int8 = _AVX_MNNFloat2Int8;
gAVX2CoreInt8Functions->MNNInt8ScaleToFloat = _AVX_MNNInt8ScaleToFloat;
// conv depthwise
gAVX2CoreInt8Functions->ConvDepthwiseLineInt8 = _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit;
// Norm
gAVX2CoreInt8Functions->MNNNormInt8 = _AVX_MNNNormInt8;
}