MNN/source/backend/cpu/x86_x64/sse/PackedFunction.cpp

190 lines
8.7 KiB
C++
Raw Normal View History

2021-09-18 15:52:30 +08:00
//
// PackedFunction.cpp
// MNN
//
// Created by MNN on 2021/07/09.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <emmintrin.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include "core/Macro.h"
#include "FunctionSummary.hpp"
// Row-wise C = clamp(A + B * beta) on packs of 4 floats.
//   C, A       : row y starts at C + cStride * y and A + aStride * y (strides counted in
//                floats); `width` packs of 4 floats are processed per row.
//   B          : one pack of 4 floats per row, read from B + 4 * y and applied to every
//                pack of that row.
//   height     : number of rows.
//   parameters : [1] is beta, [2] the lower bound, [3] the upper bound. [0] is not read.
// The upper bound is applied before the lower bound, so the lower bound wins if the two
// bounds are inverted.
void _SSE_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
    const __m128 minF = _mm_set1_ps(parameters[2]);
    const __m128 maxF = _mm_set1_ps(parameters[3]);
    const __m128 beta = _mm_set1_ps(parameters[1]);
    // Indices are size_t: the bounds are size_t, and the `4 * x` offset must not be
    // evaluated in int.
    for (size_t y = 0; y < height; ++y) {
        const float* a = A + aStride * y;
        float* c       = C + cStride * y;
        // B * beta does not depend on x, so it is computed once per row.
        const __m128 scaledB = _mm_mul_ps(_mm_loadu_ps(B + 4 * y), beta);
        for (size_t x = 0; x < width; ++x) {
            __m128 cv = _mm_add_ps(_mm_loadu_ps(a + 4 * x), scaledB);
            cv        = _mm_min_ps(cv, maxF);
            cv        = _mm_max_ps(cv, minF);
            _mm_storeu_ps(c + 4 * x, cv);
        }
    }
}
// Copies `count` packs of 4 floats. Pack i is read from source + i * srcStride and written
// to dest + i * dstStride; both strides are counted in floats.
void _SSE_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    // size_t counter to match `count` (no signed/unsigned comparison, no int overflow).
    for (size_t i = 0; i < count; ++i) {
        const float* s = source + i * srcStride;
        float* d       = dest + i * dstStride;
        _mm_storeu_ps(d, _mm_loadu_ps(s));
    }
}
// Accumulates `count` packs of 4 floats into dest: the pack at dest + i * dstStride becomes
// (pack at source + i * srcStride) + (its previous value). Strides are counted in floats.
void _SSE_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    // size_t counter to match `count` (no signed/unsigned comparison, no int overflow).
    for (size_t i = 0; i < count; ++i) {
        const float* s = source + i * srcStride;
        float* d       = dest + i * dstStride;
        _mm_storeu_ps(d, _mm_add_ps(_mm_loadu_ps(s), _mm_loadu_ps(d)));
    }
}
// ReLU with a per-channel negative slope, on packs of 4 floats.
// For each j in [0, depthQuad): the slope pack is slope[4 * j .. 4 * j + 3], and the
// `sizeQuad` packs starting at src + 4 * j * sizeQuad are written to dst at the same offset:
//   x <  0 -> x * slope
//   x >= 0 -> x
// A lane for which neither comparison holds (NaN input) ends up as 0, because both masked
// terms are zeroed before they are added.
void _SSE_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    const __m128 zero = _mm_set1_ps(0.0f);
    // size_t indices: bounds are size_t and the `4 * j` / `4 * i` offsets must not be
    // evaluated in int.
    for (size_t j = 0; j < depthQuad; ++j) {
        const __m128 slopeZ = _mm_loadu_ps(slope + 4 * j);
        const float* srcZ   = src + 4 * j * sizeQuad;
        float* dstZ         = dst + 4 * j * sizeQuad;
        for (size_t i = 0; i < sizeQuad; ++i) {
            // Named `value` so the `src` parameter is not shadowed.
            const __m128 value    = _mm_loadu_ps(srcZ + 4 * i);
            const __m128 negative = _mm_cmplt_ps(value, zero);
            const __m128 positive = _mm_cmpge_ps(value, zero);
            const __m128 scaled   = _mm_mul_ps(value, slopeZ);
            // The masks are disjoint, so the sum selects one of the two terms per lane.
            _mm_storeu_ps(dstZ + 4 * i, _mm_add_ps(_mm_and_ps(scaled, negative), _mm_and_ps(value, positive)));
        }
    }
}
// Depthwise convolution over `height` output rows of `width` output pixels, where every
// pixel is a pack of 4 floats. All offsets below are counted in floats.
//   dst     : output; row y starts at dst + y * dstHStep, pixels are 4 floats apart.
//             Values are overwritten, not accumulated into.
//   src     : input; the window for output pixel (y, x) starts at
//             src + y * srcHStep + x * src_w_setup, and tap (fy, fx) is read at
//             + fy * dilateY_step + fx * dilateX_step.
//   weight  : fh * fw packs of 4 floats; tap (fy, fx) is at weight + (fy * fw + fx) * 4.
// Each output pack is the sum over all taps of (input pack * weight pack).
//
// A row is processed in tiles of 8 pixels, then at most one tile of 4, then single pixels.
// Every output pixel accumulates its taps in the same (fy, fx) order regardless of which
// tile handles it. Loop indices are size_t so they match the size_t bounds and the
// `4 * fx` / `4 * dx` offsets are never evaluated in int.
void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                     size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep) {
    const size_t unit      = 8;
    const size_t widthUnit = width / unit;
    size_t widthRemain     = width - widthUnit * unit;
    // After the 8-wide tiles at most 7 pixels remain: one optional 4-wide tile plus 0..3
    // single pixels.
    const bool need4 = widthRemain >= 4;
    if (need4) {
        widthRemain -= 4;
    }
    for (size_t y = 0; y < height; ++y) {
        const float* srcY = src + y * srcHStep;
        float* dstY       = dst + y * dstHStep;

        // Tiles of 8 output pixels: each weight pack is loaded once and reused 8 times.
        for (size_t dx = 0; dx < widthUnit; ++dx) {
            __m128 dstValue0 = _mm_set1_ps(0.0f);
            __m128 dstValue1 = _mm_set1_ps(0.0f);
            __m128 dstValue2 = _mm_set1_ps(0.0f);
            __m128 dstValue3 = _mm_set1_ps(0.0f);
            __m128 dstValue4 = _mm_set1_ps(0.0f);
            __m128 dstValue5 = _mm_set1_ps(0.0f);
            __m128 dstValue6 = _mm_set1_ps(0.0f);
            __m128 dstValue7 = _mm_set1_ps(0.0f);
            for (size_t fy = 0; fy < fh; ++fy) {
                const float* src_y    = srcY + fy * dilateY_step;
                const float* weight_y = weight + fy * fw * 4;
                for (size_t fx = 0; fx < fw; ++fx) {
                    const float* src_x       = src_y + fx * dilateX_step;
                    const __m128 weightValue = _mm_loadu_ps(weight_y + 4 * fx);
                    dstValue0 = _mm_add_ps(dstValue0, _mm_mul_ps(_mm_loadu_ps(src_x + 0 * src_w_setup), weightValue));
                    dstValue1 = _mm_add_ps(dstValue1, _mm_mul_ps(_mm_loadu_ps(src_x + 1 * src_w_setup), weightValue));
                    dstValue2 = _mm_add_ps(dstValue2, _mm_mul_ps(_mm_loadu_ps(src_x + 2 * src_w_setup), weightValue));
                    dstValue3 = _mm_add_ps(dstValue3, _mm_mul_ps(_mm_loadu_ps(src_x + 3 * src_w_setup), weightValue));
                    dstValue4 = _mm_add_ps(dstValue4, _mm_mul_ps(_mm_loadu_ps(src_x + 4 * src_w_setup), weightValue));
                    dstValue5 = _mm_add_ps(dstValue5, _mm_mul_ps(_mm_loadu_ps(src_x + 5 * src_w_setup), weightValue));
                    dstValue6 = _mm_add_ps(dstValue6, _mm_mul_ps(_mm_loadu_ps(src_x + 6 * src_w_setup), weightValue));
                    dstValue7 = _mm_add_ps(dstValue7, _mm_mul_ps(_mm_loadu_ps(src_x + 7 * src_w_setup), weightValue));
                }
            }
            _mm_storeu_ps(dstY + 4 * 0, dstValue0);
            _mm_storeu_ps(dstY + 4 * 1, dstValue1);
            _mm_storeu_ps(dstY + 4 * 2, dstValue2);
            _mm_storeu_ps(dstY + 4 * 3, dstValue3);
            _mm_storeu_ps(dstY + 4 * 4, dstValue4);
            _mm_storeu_ps(dstY + 4 * 5, dstValue5);
            _mm_storeu_ps(dstY + 4 * 6, dstValue6);
            _mm_storeu_ps(dstY + 4 * 7, dstValue7);
            // Advance both cursors past the tile just written.
            dstY += 4 * unit;
            srcY += unit * src_w_setup;
        }

        // Optional tile of 4 output pixels.
        if (need4) {
            __m128 dstValue0 = _mm_set1_ps(0.0f);
            __m128 dstValue1 = _mm_set1_ps(0.0f);
            __m128 dstValue2 = _mm_set1_ps(0.0f);
            __m128 dstValue3 = _mm_set1_ps(0.0f);
            for (size_t fy = 0; fy < fh; ++fy) {
                const float* src_y    = srcY + fy * dilateY_step;
                const float* weight_y = weight + fy * fw * 4;
                for (size_t fx = 0; fx < fw; ++fx) {
                    const float* src_x       = src_y + fx * dilateX_step;
                    const __m128 weightValue = _mm_loadu_ps(weight_y + 4 * fx);
                    dstValue0 = _mm_add_ps(dstValue0, _mm_mul_ps(_mm_loadu_ps(src_x + 0 * src_w_setup), weightValue));
                    dstValue1 = _mm_add_ps(dstValue1, _mm_mul_ps(_mm_loadu_ps(src_x + 1 * src_w_setup), weightValue));
                    dstValue2 = _mm_add_ps(dstValue2, _mm_mul_ps(_mm_loadu_ps(src_x + 2 * src_w_setup), weightValue));
                    dstValue3 = _mm_add_ps(dstValue3, _mm_mul_ps(_mm_loadu_ps(src_x + 3 * src_w_setup), weightValue));
                }
            }
            _mm_storeu_ps(dstY + 4 * 0, dstValue0);
            _mm_storeu_ps(dstY + 4 * 1, dstValue1);
            _mm_storeu_ps(dstY + 4 * 2, dstValue2);
            _mm_storeu_ps(dstY + 4 * 3, dstValue3);
            dstY += 4 * 4;
            srcY += 4 * src_w_setup;
        }

        // Remaining 0..3 output pixels, one at a time.
        for (size_t dx = 0; dx < widthRemain; ++dx) {
            __m128 dstValue    = _mm_set1_ps(0.0f);
            const float* src_z = srcY + src_w_setup * dx;
            for (size_t fy = 0; fy < fh; ++fy) {
                const float* src_y    = src_z + fy * dilateY_step;
                const float* weight_y = weight + fy * fw * 4;
                for (size_t fx = 0; fx < fw; ++fx) {
                    const float* weight_x = weight_y + 4 * fx;
                    const float* src_x    = src_y + fx * dilateX_step;
                    dstValue = _mm_add_ps(dstValue, _mm_mul_ps(_mm_loadu_ps(src_x), _mm_loadu_ps(weight_x)));
                }
            }
            _mm_storeu_ps(dstY + 4 * dx, dstValue);
        }
    }
}
// Element-wise C = A - B on `height` rows of `widthC4` packs of 4 floats.
// Row y of each matrix starts at base + stride * y, with cStride / aStride / bStride
// counted in floats.
void _SSE_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    // size_t indices: bounds are size_t and `4 * x` must not be evaluated in int.
    for (size_t y = 0; y < height; ++y) {
        const float* a = A + aStride * y;
        const float* b = B + bStride * y;
        float* c       = C + cStride * y;
        for (size_t x = 0; x < widthC4; ++x) {
            const size_t offset = 4 * x;
            _mm_storeu_ps(c + offset, _mm_sub_ps(_mm_loadu_ps(a + offset), _mm_loadu_ps(b + offset)));
        }
    }
}
// Element-wise C = A + B on `height` rows of `widthC4` packs of 4 floats.
// Row y of each matrix starts at base + stride * y, with cStride / aStride / bStride
// counted in floats.
void _SSE_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    // size_t indices: bounds are size_t and `4 * x` must not be evaluated in int.
    for (size_t y = 0; y < height; ++y) {
        const float* a = A + aStride * y;
        const float* b = B + bStride * y;
        float* c       = C + cStride * y;
        for (size_t x = 0; x < widthC4; ++x) {
            const size_t offset = 4 * x;
            _mm_storeu_ps(c + offset, _mm_add_ps(_mm_loadu_ps(b + offset), _mm_loadu_ps(a + offset)));
        }
    }
}
void _SSE_ExtraInit(void* functions) {
auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
coreFunction->MNNMatrixAdd = _SSE_MNNMatrixAdd;
coreFunction->MNNMatrixSub = _SSE_MNNMatrixSub;
coreFunction->MNNConvRunForLineDepthwise = _SSE_MNNConvRunForLineDepthwise;
coreFunction->MNNAxByClampBroadcastUnit = _SSE_MNNAxByClampBroadcastUnit;
}