2021-09-18 15:52:30 +08:00
|
|
|
//
|
|
|
|
// PackedFunction.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2021/07/09.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <emmintrin.h>
|
|
|
|
#include <string.h>
|
|
|
|
#include <algorithm>
|
|
|
|
#include <math.h>
|
|
|
|
#include "core/Macro.h"
|
|
|
|
#include "FunctionSummary.hpp"
|
|
|
|
// Row-wise broadcast update over packs of 4 floats, followed by a clamp:
//   C[y][x] = clamp(A[y][x] + parameters[1] * B[y], parameters[2], parameters[3])
// Row y of A starts at A + y * aStride and row y of C at C + y * cStride (strides are in
// floats). B supplies one 4-float pack per row, located at B + 4 * y. Each row holds
// `width` packs. parameters[0] is not read by this kernel.
void _SSE_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
    const __m128 scaleV = _mm_set1_ps(parameters[1]);
    const __m128 lowerV = _mm_set1_ps(parameters[2]);
    const __m128 upperV = _mm_set1_ps(parameters[3]);
    for (size_t row = 0; row < height; ++row) {
        const float* srcRow = A + row * aStride;
        float* dstRow       = C + row * cStride;
        // The broadcast term is identical for every pack in this row, so scale it once.
        const __m128 offsetV = _mm_mul_ps(_mm_loadu_ps(B + 4 * row), scaleV);
        for (size_t col = 0; col < width; ++col) {
            __m128 value = _mm_add_ps(_mm_loadu_ps(srcRow + 4 * col), offsetV);
            value        = _mm_min_ps(value, upperV);
            value        = _mm_max_ps(value, lowerV);
            _mm_storeu_ps(dstRow + 4 * col, value);
        }
    }
}
|
|
|
|
|
|
|
|
// Copies `count` packs of 4 floats. Pack n is read from source + n * srcStride and
// written to dest + n * dstStride (strides are in floats). Unaligned loads/stores are
// used, so neither pointer needs 16-byte alignment.
void _SSE_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    for (size_t n = 0; n != count; ++n) {
        const __m128 pack = _mm_loadu_ps(source + n * srcStride);
        _mm_storeu_ps(dest + n * dstStride, pack);
    }
}
|
|
|
|
|
|
|
|
// Accumulates `count` packs of 4 floats into dest. For pack n:
//   dest[n * dstStride .. +3] = source[n * srcStride .. +3] + dest[n * dstStride .. +3]
// Strides are in floats; loads/stores are unaligned.
void _SSE_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    for (size_t n = 0; n != count; ++n) {
        float* target      = dest + n * dstStride;
        const __m128 addend = _mm_loadu_ps(source + n * srcStride);
        const __m128 prior  = _mm_loadu_ps(target);
        _mm_storeu_ps(target, _mm_add_ps(addend, prior));
    }
}
|
|
|
|
|
|
|
|
// Per-lane leaky rectifier with a slope per 4-float pack of channels:
//   dst = (src < 0) ? src * slope : src
// Data is organised as depthQuad groups. Group j occupies 4 * sizeQuad floats starting at
// src + 4 * j * sizeQuad (same offset in dst) and uses the 4 slopes at slope + 4 * j.
// Loads and stores are unaligned. In-place use (dst == src) is safe because each pack is
// loaded before its own store and no pack is read after being written.
void _SSE_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    const __m128 zero = _mm_setzero_ps();
    for (size_t j = 0; j < depthQuad; ++j) {
        const __m128 slopeZ = _mm_loadu_ps(slope + 4 * j);
        const float* srcZ   = src + 4 * j * sizeQuad;
        float* dstZ         = dst + 4 * j * sizeQuad;
        for (size_t i = 0; i < sizeQuad; ++i) {
            const __m128 x = _mm_loadu_ps(srcZ + 4 * i);
            // All-ones in lanes where x < 0. NaN compares false, so NaN lanes take the
            // pass-through path and propagate instead of being zeroed.
            const __m128 negMask = _mm_cmplt_ps(x, zero);
            const __m128 scaled  = _mm_mul_ps(x, slopeZ);
            // Branch-free select: negMask ? scaled : x (SSE2 has no blendv).
            const __m128 result = _mm_or_ps(_mm_and_ps(negMask, scaled), _mm_andnot_ps(negMask, x));
            _mm_storeu_ps(dstZ + 4 * i, result);
        }
    }
}
|
|
|
|
|
|
|
|
// Computes N horizontally adjacent output pixels (4 float lanes each) of one
// depthwise-convolution output row and stores them at dstPtr, dstPtr + 4, ...
// srcPtr is the top-left input tap of the first output pixel. Output pixel n reads its
// taps starting src_w_setup * n floats further on. Within the fw x fh window, taps are
// dilateX_step / dilateY_step floats apart. Weights are 4 floats per tap, laid out
// row-major as weight + (ky * fw + kx) * 4. Each accumulator starts at `bias`, sums
// tap * weight in ky-then-kx order, and is clamped with min(., upper) then max(., lower).
template <int N>
static inline void _SSE_DepthwiseTile(float* dstPtr, const float* srcPtr, const float* weight, size_t src_w_setup,
                                      size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step,
                                      __m128 bias, __m128 lower, __m128 upper) {
    __m128 acc[N];
    for (int n = 0; n < N; ++n) {
        acc[n] = bias;
    }
    for (size_t ky = 0; ky < fh; ++ky) {
        const float* srcRow = srcPtr + ky * dilateY_step;
        const float* wRow   = weight + ky * fw * 4;
        for (size_t kx = 0; kx < fw; ++kx) {
            const float* tap = srcRow + kx * dilateX_step;
            const __m128 w   = _mm_loadu_ps(wRow + 4 * kx);
            for (int n = 0; n < N; ++n) {
                acc[n] = _mm_add_ps(acc[n], _mm_mul_ps(_mm_loadu_ps(tap + n * src_w_setup), w));
            }
        }
    }
    for (int n = 0; n < N; ++n) {
        _mm_storeu_ps(dstPtr + 4 * n, _mm_max_ps(_mm_min_ps(acc[n], upper), lower));
    }
}

// Depthwise convolution over `height` output rows of `width` pixels (4 float lanes each),
// with a per-lane bias (4 floats at `bias`) and a clamp to [parameters[0], parameters[1]].
// Output row y starts at dst + y * dstHStep; its input window starts at src + y * srcHStep.
// All strides are in floats. Each row is processed as tiles of 8 pixels, then at most one
// tile of 4, then single pixels.
void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters) {
    const size_t wideTiles = width / 8;
    size_t singles         = width % 8;
    const bool hasQuad     = singles >= 4;
    if (hasQuad) {
        singles -= 4;
    }
    const __m128 lower = _mm_set1_ps(parameters[0]);
    const __m128 upper = _mm_set1_ps(parameters[1]);
    const __m128 biasV = _mm_loadu_ps(bias);

    for (size_t row = 0; row < height; ++row) {
        const float* srcCursor = src + row * srcHStep;
        float* dstCursor       = dst + row * dstHStep;

        for (size_t t = 0; t < wideTiles; ++t) {
            _SSE_DepthwiseTile<8>(dstCursor, srcCursor, weight, src_w_setup, fw, fh, dilateX_step, dilateY_step,
                                  biasV, lower, upper);
            dstCursor += 4 * 8;
            srcCursor += 8 * src_w_setup;
        }
        if (hasQuad) {
            _SSE_DepthwiseTile<4>(dstCursor, srcCursor, weight, src_w_setup, fw, fh, dilateX_step, dilateY_step,
                                  biasV, lower, upper);
            dstCursor += 4 * 4;
            srcCursor += 4 * src_w_setup;
        }
        for (size_t t = 0; t < singles; ++t) {
            _SSE_DepthwiseTile<1>(dstCursor + 4 * t, srcCursor + src_w_setup * t, weight, src_w_setup, fw, fh,
                                  dilateX_step, dilateY_step, biasV, lower, upper);
        }
    }
}
|
|
|
|
// Element-wise C = A - B over `height` rows, each holding `widthC4` packs of 4 floats.
// Every operand has its own row stride (cStride / aStride / bStride, in floats).
void _SSE_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    const size_t rowFloats = 4 * widthC4;
    for (size_t row = 0; row < height; ++row) {
        const float* lhs = A + row * aStride;
        const float* rhs = B + row * bStride;
        float* out       = C + row * cStride;
        for (size_t k = 0; k < rowFloats; k += 4) {
            const __m128 diff = _mm_sub_ps(_mm_loadu_ps(lhs + k), _mm_loadu_ps(rhs + k));
            _mm_storeu_ps(out + k, diff);
        }
    }
}
|
|
|
|
// Element-wise C = A + B over `height` rows, each holding `widthC4` packs of 4 floats.
// Every operand has its own row stride (cStride / aStride / bStride, in floats).
void _SSE_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    const size_t rowFloats = 4 * widthC4;
    for (size_t row = 0; row < height; ++row) {
        const float* lhs = A + row * aStride;
        const float* rhs = B + row * bStride;
        float* out       = C + row * cStride;
        for (size_t k = 0; k < rowFloats; k += 4) {
            // Operand order (B first) is kept as-is.
            const __m128 sum = _mm_add_ps(_mm_loadu_ps(rhs + k), _mm_loadu_ps(lhs + k));
            _mm_storeu_ps(out + k, sum);
        }
    }
}
|
|
|
|
void _SSE_ExtraInit(void* functions) {
|
|
|
|
auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
|
|
|
|
coreFunction->MNNMatrixAdd = _SSE_MNNMatrixAdd;
|
|
|
|
coreFunction->MNNMatrixSub = _SSE_MNNMatrixSub;
|
|
|
|
coreFunction->MNNConvRunForLineDepthwise = _SSE_MNNConvRunForLineDepthwise;
|
|
|
|
coreFunction->MNNAxByClampBroadcastUnit = _SSE_MNNAxByClampBroadcastUnit;
|
|
|
|
}
|