mirror of https://github.com/alibaba/MNN.git
344 lines
13 KiB
C++
344 lines
13 KiB
C++
//
//  MathFunctions.cpp
//  MNN
//
//  Created by MNN on 2021/07/09.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
|
|
|
#include <emmintrin.h>
|
|
#include <string.h>
|
|
#include <algorithm>
|
|
#include <math.h>
|
|
#include "core/Macro.h"
|
|
#include "FunctionSummary.hpp"
|
|
|
|
// Vectorised exponential over packs of eight floats:
//     dest[k] = exp-approx(source[k] * offset[0]) + offset[1]
//
// dest / source : 8 * countC8 floats each (unaligned access is used).
// offset        : [0] scale applied to the input, [1] bias added to the output.
// parameters    : [0] range-reduction step and [1] its reciprocal
//                 (presumably ln(2) and 1/ln(2) -- confirm against callers),
//                 [2..7] polynomial coefficients, constant term first.
// countC8       : number of 8-float packs; two SSE vectors are handled per pack.
//
// The scaled input is clamped to [-87, 87] before evaluation.
void _SSE_MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) {
    const auto vectorCount = countC8 * 2;

    const __m128 inputScale = _mm_set1_ps(offset[0]);
    const __m128 outputBias = _mm_set1_ps(offset[1]);
    const __m128 step       = _mm_set1_ps(parameters[0]);
    const __m128 invStep    = _mm_set1_ps(parameters[1]);
    const __m128 coef0      = _mm_set1_ps(parameters[2]);
    const __m128 coef1      = _mm_set1_ps(parameters[3]);
    const __m128 coef2      = _mm_set1_ps(parameters[4]);
    const __m128 coef3      = _mm_set1_ps(parameters[5]);
    const __m128 coef4      = _mm_set1_ps(parameters[6]);
    const __m128 coef5      = _mm_set1_ps(parameters[7]);
    const __m128 upperBound = _mm_set1_ps(87);
    const __m128 lowerBound = _mm_set1_ps(-87);

    for (int v = 0; v < vectorCount; ++v) {
        // Scale, then clamp so the exponent built below stays representable.
        __m128 x = _mm_mul_ps(_mm_loadu_ps(source + v * 4), inputScale);
        x        = _mm_min_ps(_mm_max_ps(x, lowerBound), upperBound);

        // n = round(x * invStep); 2^n is built by writing (n + 127) into the
        // float exponent field.
        const __m128i n      = _mm_cvtps_epi32(_mm_mul_ps(x, invStep));
        const __m128  nFloat = _mm_cvtepi32_ps(n);
        const __m128  pow2n  = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(n, _mm_set1_epi32(127)), 23));

        // Remainder r = x - n * step, evaluated with a Horner polynomial.
        const __m128 r = _mm_sub_ps(x, _mm_mul_ps(nFloat, step));
        __m128 poly    = _mm_add_ps(_mm_mul_ps(coef5, r), coef4);
        poly           = _mm_add_ps(_mm_mul_ps(poly, r), coef3);
        poly           = _mm_add_ps(_mm_mul_ps(poly, r), coef2);
        poly           = _mm_add_ps(_mm_mul_ps(poly, r), coef1);
        poly           = _mm_add_ps(_mm_mul_ps(poly, r), coef0);

        _mm_storeu_ps(dest + 4 * v, _mm_add_ps(_mm_mul_ps(pow2n, poly), outputBias));
    }
}
|
|
|
|
// Softmax over a contiguous float array: dest[i] = exp(source[i] - max) / sum.
//
// dest   : output, `size` floats (also used as scratch for the exponentials).
// source : input, `size` floats.
// size   : element count; an empty input is a no-op.
//
// Whole groups of four floats go through SSE; the remaining 0..3 elements are
// handled by a scalar tail. exp() is approximated as 2^n * P(r) with
// r = x - n * ln(2) and P a degree-5 Taylor polynomial; x is clamped to
// [-87, 87] first.
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) {
    if (size == 0) {
        // Nothing to normalise; also avoids reading source[0] below.
        return;
    }
    const size_t vecCount  = size / 4;
    const size_t tailStart = vecCount * 4;
    float tmpfloat4[4];

    // step 1: get maxValue
    float maxValue = source[0];
    if (vecCount > 0) {
        auto maxVal = _mm_loadu_ps(source);
        for (size_t i = 1; i < vecCount; i++) {
            maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4));
        }
        _mm_storeu_ps(tmpfloat4, maxVal);
        maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1];
        maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2];
        maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3];
    }
    for (size_t i = tailStart; i < size; i++) {
        maxValue = maxValue > source[i] ? maxValue : source[i];
    }

    // step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
    float sumValue = 0.f;
    if (vecCount > 0) {
        auto sumVal = _mm_set1_ps(0.f);
        auto p0     = _mm_set1_ps(0.6931471805599453);   // ln(2)
        auto p1     = _mm_set1_ps(1.4426950408889634);   // 1 / ln(2)
        auto p2     = _mm_set1_ps(1.f);
        auto p3     = _mm_set1_ps(1.f);
        auto p4     = _mm_set1_ps(0.5);
        auto p5     = _mm_set1_ps(0.1666666666666666);
        auto p6     = _mm_set1_ps(0.041666666666666664);
        auto p7     = _mm_set1_ps(0.008333333333333333);
        auto xMax   = _mm_set1_ps(87);
        auto xMin   = _mm_set1_ps(-87);
        const auto maxBroadcast = _mm_set1_ps(maxValue);  // loop invariant
        const auto exponentBias = _mm_set1_epi32(127);
        for (size_t i = 0; i < vecCount; ++i) {
            auto x = _mm_sub_ps(_mm_loadu_ps(source + i * 4), maxBroadcast);
            x      = _mm_max_ps(x, xMin);
            x      = _mm_min_ps(x, xMax);
            // n = round(x / ln2); 2^n is built through the float exponent bits.
            auto div      = _mm_mul_ps(x, p1);
            auto divInt   = _mm_cvtps_epi32(div);
            div           = _mm_cvtepi32_ps(divInt);
            auto div2     = _mm_add_epi32(divInt, exponentBias);
            div2          = _mm_slli_epi32(div2, 23);
            auto expBasic = _mm_castsi128_ps(div2);
            // Polynomial on the remainder (Horner form).
            auto t         = _mm_sub_ps(x, _mm_mul_ps(div, p0));
            auto c0        = _mm_mul_ps(p7, t);
            auto c1        = _mm_add_ps(c0, p6);
            auto c2        = _mm_mul_ps(c1, t);
            auto c3        = _mm_add_ps(c2, p5);
            auto c4        = _mm_mul_ps(c3, t);
            auto c5        = _mm_add_ps(c4, p4);
            auto c6        = _mm_mul_ps(c5, t);
            auto c7        = _mm_add_ps(c6, p3);
            auto c8        = _mm_mul_ps(c7, t);
            auto expRemain = _mm_add_ps(c8, p2);
            auto expRes    = _mm_mul_ps(expBasic, expRemain);
            sumVal         = _mm_add_ps(expRes, sumVal);
            _mm_storeu_ps(dest + 4 * i, expRes);
        }
        _mm_storeu_ps(tmpfloat4, sumVal);
        sumValue = tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3];
    }

    // Scalar tail. `param` is a double, so this arithmetic runs in double
    // precision, and the quotient is truncated (the SSE path rounds).
    const auto param   = 0.6931471805599453;
    const float xLimit = 87;
    static_assert(sizeof(int) == sizeof(float), "exponent bits are copied between int and float");
    for (size_t i = tailStart; i < size; i++) {
        auto x = source[i] - maxValue;
        x      = x > -xLimit ? x : -xLimit;
        x      = x < xLimit ? x : xLimit;

        const int div  = static_cast<int>(x / param);
        const int div2 = (div + 127) << 23;
        // memcpy instead of a pointer cast: reading an int through a float*
        // violates strict aliasing.
        float expBasic;
        memcpy(&expBasic, &div2, sizeof(expBasic));

        const auto t         = x - div * param;
        const auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
        dest[i]              = expBasic * expRemain;
        sumValue += dest[i];
    }

    // step 3: get x / sum and store
    // Computed as 1 / ((1 / x) * sum) rather than x * (1 / sum) or x / sum; the
    // original author notes this works around an issue seen on some Intel CPUs.
    // _mm_rcp_ps is an approximate reciprocal, so these lanes are less precise
    // than the scalar tail below.
    const auto sumBroadcast = _mm_set1_ps(sumValue);
    for (size_t i = 0; i < vecCount; ++i) {
        auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i));
        auto z = _mm_rcp_ps(_mm_mul_ps(x, sumBroadcast));
        _mm_storeu_ps(dest + 4 * i, z);
    }
    sumValue = 1.f / sumValue;
    for (size_t i = tailStart; i < size; i++) {
        dest[i] *= sumValue;
    }
}
|
|
|
|
// GELU via the tanh formulation:
//     dst = 0.5 * x * (1 + tanh(c1 * (x + c0 * x^3)))
// with tanh replaced by a rational approximation clamped to [-1, 1].
//
// dst / src  : 8 * size floats each (unaligned access is used).
// size       : number of 8-float packs; two SSE vectors are handled per pack.
// parameters : eight constants, documented in the original source as
//              {0.044715f, 0.79788458f, 378.f, 17325.f, 135135.f, 28.f, 3150.f, 62370.f}.
//              parameters[4] is the constant term of both numerator and denominator.
void _SSE_MNNGelu(float* dst, const float* src, size_t size, float* parameters) {
    const __m128 cubicCoef  = _mm_set1_ps(parameters[0]);
    const __m128 innerScale = _mm_set1_ps(parameters[1]);
    const __m128 numC2      = _mm_set1_ps(parameters[2]);
    const __m128 numC1      = _mm_set1_ps(parameters[3]);
    const __m128 numC0      = _mm_set1_ps(parameters[4]);
    const __m128 denC3      = _mm_set1_ps(parameters[5]);
    const __m128 denC2      = _mm_set1_ps(parameters[6]);
    const __m128 denC1      = _mm_set1_ps(parameters[7]);
    const __m128 denC0      = _mm_set1_ps(parameters[4]);
    const __m128 half       = _mm_set1_ps(0.5);
    const __m128 plusOne    = _mm_set1_ps(1.f);
    const __m128 minusOne   = _mm_set1_ps(-1.f);

    for (int v = 0; v < size * 2; v++) {
        const __m128 x = _mm_loadu_ps(src + v * 4);

        // u = innerScale * (x + cubicCoef * x^3)
        __m128 u = _mm_mul_ps(_mm_mul_ps(x, x), x);
        u        = _mm_add_ps(_mm_mul_ps(u, cubicCoef), x);
        u        = _mm_mul_ps(u, innerScale);

        // tanh(u) ~= u * N(u^2) / D(u^2), both evaluated in Horner form.
        const __m128 u2 = _mm_mul_ps(u, u);

        __m128 numer = _mm_add_ps(u2, numC2);
        numer        = _mm_add_ps(_mm_mul_ps(numer, u2), numC1);
        numer        = _mm_add_ps(_mm_mul_ps(numer, u2), numC0);
        numer        = _mm_mul_ps(numer, u);

        __m128 denom = _mm_add_ps(_mm_mul_ps(u2, denC3), denC2);
        denom        = _mm_add_ps(_mm_mul_ps(denom, u2), denC1);
        denom        = _mm_add_ps(_mm_mul_ps(denom, u2), denC0);

        __m128 th = _mm_div_ps(numer, denom);
        th        = _mm_min_ps(_mm_max_ps(th, minusOne), plusOne);

        // 0.5 * x * (1 + tanh)
        __m128 result = _mm_mul_ps(_mm_add_ps(th, plusOne), x);
        result        = _mm_mul_ps(result, half);
        _mm_storeu_ps(dst + v * 4, result);
    }
}
|
|
|
|
// Hard-swish: dst = x * clamp(x + 3, 0, 6) / 6.
//
// dst / src : 4 * size floats each (unaligned access is used).
// size      : number of 4-float vectors to process, not the element count.
void _SSE_MNNHardSwish(float* dst, const float* src, size_t size) {
    const __m128 lowerBound = _mm_set1_ps(0.f);
    const __m128 shift      = _mm_set1_ps(3.f);
    const __m128 upperBound = _mm_set1_ps(6.f);
    for (int v = 0; v < size; v++) {
        const __m128 x = _mm_loadu_ps(src + 4 * v);
        // gate = min(max(x + 3, 0), 6)
        __m128 gate = _mm_add_ps(x, shift);
        gate        = _mm_max_ps(gate, lowerBound);
        gate        = _mm_min_ps(gate, upperBound);
        const __m128 scaled = _mm_mul_ps(x, gate);
        _mm_storeu_ps(dst + 4 * v, _mm_div_ps(scaled, upperBound));
    }
}
|
|
|
|
// Normalises `size` floats to zero mean and unit variance, optionally followed
// by an element-wise affine transform:
//     dst[i] = (src[i] - mean) * gamma[i] / sqrt(var + epsilon) + beta[i]
//
// dst / src    : `size` floats each.
// gamma / beta : per-element scale and shift. The affine step is applied only
//                when BOTH are non-null; otherwise both are ignored.
// epsilon      : added to the variance before the square root.
// size         : element count. Whole groups of four go through SSE, the
//                remaining 0..3 elements through scalar code.
void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size) {
    const int quadCount = static_cast<int32_t>(size / 4);
    const int tailBegin = quadCount * 4;

    // Adds the four lanes of a vector in lane order 0,1,2,3.
    auto laneSum = [](__m128 v) {
        float lanes[4];
        _mm_storeu_ps(lanes, v);
        return lanes[0] + lanes[1] + lanes[2] + lanes[3];
    };

    // step 1: mean
    float total = 0.f;
    if (quadCount > 0) {
        __m128 acc = _mm_set1_ps(0.f);
        for (int q = 0; q < quadCount; q++) {
            acc = _mm_add_ps(acc, _mm_loadu_ps(src + q * 4));
        }
        total += laneSum(acc);
    }
    for (int k = tailBegin; k < size; k++) {
        total += src[k];
    }
    const float mean     = total / size;
    const __m128 meanVec = _mm_set1_ps(mean);

    // step 2: sum of squared deviations
    float squareTotal = 0.f;
    if (quadCount > 0) {
        __m128 acc = _mm_set1_ps(0.f);
        for (int q = 0; q < quadCount; q++) {
            const __m128 centered = _mm_sub_ps(_mm_loadu_ps(src + q * 4), meanVec);
            acc                   = _mm_add_ps(acc, _mm_mul_ps(centered, centered));
        }
        squareTotal += laneSum(acc);
    }
    for (int k = tailBegin; k < size; k++) {
        const float centered = (src[k] - mean);
        squareTotal += centered * centered;
    }

    // step 3: scale by the reciprocal standard deviation
    float invStd = squareTotal / size;
    invStd       = 1.f / sqrt(invStd + epsilon);
    const __m128 invStdVec = _mm_set1_ps(invStd);

    const bool affine = gamma && beta;
    if (!affine) {
        for (int q = 0; q < quadCount; q++) {
            const __m128 centered = _mm_sub_ps(_mm_loadu_ps(src + q * 4), meanVec);
            _mm_storeu_ps(dst + q * 4, _mm_mul_ps(centered, invStdVec));
        }
        for (int k = tailBegin; k < size; k++) {
            dst[k] = (src[k] - mean) * invStd;
        }
        return;
    }

    for (int q = 0; q < quadCount; q++) {
        const __m128 centered = _mm_sub_ps(_mm_loadu_ps(src + q * 4), meanVec);
        const __m128 scale    = _mm_loadu_ps(gamma + q * 4);
        const __m128 shift    = _mm_loadu_ps(beta + q * 4);
        const __m128 scaled   = _mm_mul_ps(_mm_mul_ps(centered, scale), invStdVec);
        _mm_storeu_ps(dst + q * 4, _mm_add_ps(scaled, shift));
    }
    for (int k = tailBegin; k < size; k++) {
        dst[k] = (src[k] - mean) * gamma[k] * invStd + beta[k];
    }
}
|
|
|
|
void _SSE_MNNNormInt8(int8_t* dst, const int8_t* src, const float* gamma, const float* beta, float epsilon, size_t size, QuanPrePostParameters* params) {
|
|
float tmpfloat4[4];
|
|
int count = static_cast<int32_t>(size / 4);
|
|
int remain = count * 4;
|
|
float sum = 0.f;
|
|
std::vector<float> inpf(size);
|
|
std::vector<float> outf(size);
|
|
std::vector<float> inpScale(4, params->inputScale[0]);
|
|
std::vector<float> outScale(4, params->outputScale[0]);
|
|
float* srcf = inpf.data();
|
|
float* dstf = outf.data();
|
|
// step 0: Int8 -> Float
|
|
_SSE_MNNInt8ScaleToFloat(inpf.data(), src, inpScale.data(), size / 4, params->inputZeroPoint[0]);
|
|
// step 1: get sum
|
|
if (count > 0) {
|
|
auto sumVal = _mm_set1_ps(0.f);
|
|
for (int i = 0; i < count; i++) {
|
|
sumVal = _mm_add_ps(sumVal, _mm_loadu_ps(srcf + i * 4));
|
|
}
|
|
_mm_storeu_ps(tmpfloat4, sumVal);
|
|
sum += (tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3]);
|
|
}
|
|
for (int i = remain; i < size; i++) {
|
|
sum += srcf[i];
|
|
}
|
|
// step 2: get square_sum
|
|
float mean = sum / size;
|
|
float square_sum = 0.f;
|
|
auto meanVal = _mm_set1_ps(mean);
|
|
if (count > 0) {
|
|
auto sumVal = _mm_set1_ps(0.f);
|
|
for (int i = 0; i < count; i++) {
|
|
auto x = _mm_sub_ps(_mm_loadu_ps(srcf + i * 4), meanVal);
|
|
sumVal = _mm_add_ps(sumVal, _mm_mul_ps(x, x));
|
|
}
|
|
_mm_storeu_ps(tmpfloat4, sumVal);
|
|
square_sum += (tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3]);
|
|
}
|
|
for (int i = remain; i < size; i++) {
|
|
float x = (srcf[i] - mean);
|
|
square_sum += x * x;
|
|
}
|
|
// step 3: get result
|
|
float variable = square_sum / size;
|
|
variable = 1.f / sqrt(variable + epsilon);
|
|
auto variableVal = _mm_set1_ps(variable);
|
|
if (gamma && beta) {
|
|
for (int i = 0; i < count; i++) {
|
|
auto x = _mm_sub_ps(_mm_loadu_ps(srcf + i * 4), meanVal);
|
|
auto g = _mm_loadu_ps(gamma + i * 4);
|
|
auto b = _mm_loadu_ps(beta + i * 4);
|
|
auto y = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(x, g), variableVal), b);
|
|
_mm_storeu_ps(dstf + i * 4, y);
|
|
}
|
|
for (int i = remain; i < size; i++) {
|
|
dstf[i] = (src[i] - mean) * gamma[i] * variable + beta[i] ;
|
|
}
|
|
} else {
|
|
for (int i = 0; i < count; i++) {
|
|
auto x = _mm_sub_ps(_mm_loadu_ps(srcf + i * 4), meanVal);
|
|
auto y = _mm_mul_ps(x, variableVal);
|
|
_mm_storeu_ps(dstf + i * 4, y);
|
|
}
|
|
for (int i = remain; i < size; i++) {
|
|
dstf[i] = (srcf[i] - mean) * variable;
|
|
}
|
|
}
|
|
// step 4: Float -> Int8
|
|
_SSE_MNNFloat2Int8(dstf, dst, size / 4, outScale.data(), params->minValue, params->maxValue, params->outputZeroPoint[0]);
|
|
}
|