mirror of https://github.com/alibaba/MNN.git
216 lines
9.7 KiB
Metal
216 lines
9.7 KiB
Metal
//
// MetalConvolution.metal
// MNN
//
// Created by MNN on 2018/08/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
|
|
|
|
#include <metal_stdlib>
|
|
#include "MetalConvolutionActivation.metal"
|
|
|
|
using namespace metal;
|
|
using namespace MNN;
|
|
|
|
#define CONV_UNROLL (4)
|
|
|
|
// Quantizes one ftype4 element per thread into a char4.
//   in    : source values, indexed by the thread's grid position
//   out   : destination char4 values, same indexing as `in`
//   scale : multiplier applied (in float precision) before rounding
//   range : (x, y) = lower / upper clamp limits applied to the rounded integers
// No bounds check is performed on `gid` here.
kernel void conv_quantize(const device ftype4 *in   [[buffer(0)]],
                          device char4 *out         [[buffer(1)]],
                          constant float& scale     [[buffer(2)]],
                          constant int2& range      [[buffer(3)]],
                          uint gid                  [[thread_position_in_grid]]) {
    const int index = int(gid);

    // The conversion has to pass through int4: ftype4 -> int4 -> char4.
    // Going directly from ftype4 to char4 gives wrong results.
    const float4 scaled  = float4(in[index]) * scale;
    const int4   rounded = int4(round(scaled));
    const int4   limited = clamp(rounded, range.x, range.y);

    out[index] = char4(limited);
}
|
|
|
|
// Parameters shared by the convolution kernels in this file (bound at buffer(2)).
// Field order and types define the buffer layout; do not reorder.
struct conv_constants {
    // --- input ---
    int input_width;              // elements per input row; also the upper bound for valid x
    int input_height;             // upper bound for valid y
    int input_size;               // element stride between consecutive input slices
                                  // (presumably input_width * input_height - TODO confirm with host code)
    int input_slice;              // number of input slices the kernels iterate over

    // --- output ---
    int output_width;             // elements per output row; threads with x >= this do not store
    int output_height;            // threads with y >= this exit early in conv / conv_z4
    int output_size;              // element stride between consecutive output slices
    int output_slice;             // number of output slices; used for tail checks along z
    int threadgroup_input_slice;  // input slices staged in threadgroup memory per step (conv_local)

    // --- kernel window ---
    int kernel_x;                 // kernel taps along x; also the row stride inside the weights
    int kernel_y;                 // kernel taps along y
    int kernel_size;              // weight stride between consecutive input slices
                                  // (presumably kernel_x * kernel_y - TODO confirm with host code)
    int stride_x;
    int stride_y;
    int pad_x;
    int pad_y;
    int dilation_x;
    int dilation_y;

    conv_activation_type activation;  // forwarded to activate() when storing results
};
|
|
|
|
// Direct convolution: each thread produces one output element (one ftype4) at
// grid position (x, y, z), where z selects the output slice.
//   in        : input elements, addressed as slice * input_size + y * input_width + x
//   out       : output elements, addressed as z * output_size + y * output_width + x
//   cst       : convolution parameters
//   wt        : ftype4x4 weights, addressed as
//               (z * input_slice + slice) * kernel_size + ky * kernel_x + kx
//   biasTerms : one ftype4 per output slice, used to seed the accumulator
// Threads outside the output width/height exit without writing. gid.z is not checked.
kernel void conv(const device ftype4 *in            [[buffer(0)]],
                 device ftype4 *out                 [[buffer(1)]],
                 constant conv_constants& cst       [[buffer(2)]],
                 const device ftype4x4 *wt          [[buffer(3)]],
                 const device ftype4 *biasTerms     [[buffer(4)]],
                 uint3 gid                          [[thread_position_in_grid]]) {
    const int out_x = (int)gid.x;
    const int out_y = (int)gid.y;
    const int out_z = (int)gid.z;
    if (out_x >= cst.output_width || out_y >= cst.output_height) {
        return;
    }

    // Input coordinate hit by kernel tap (0, 0); negative inside the padding.
    const int origin_x = out_x * cst.stride_x - cst.pad_x;
    const int origin_y = out_y * cst.stride_y - cst.pad_y;

    // Restrict the tap range to taps whose input coordinate lies inside the image,
    // so the accumulation loop needs no per-tap bounds test.
    const int first_kx = max(0, UP_DIV(-origin_x, cst.dilation_x));
    const int last_kx  = min(cst.kernel_x, UP_DIV(cst.input_width - origin_x, cst.dilation_x));
    const int first_ky = max(0, UP_DIV(-origin_y, cst.dilation_y));
    const int last_ky  = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));
    const short taps_x = last_kx - first_kx;
    const short taps_y = last_ky - first_ky;

    // Input coordinate of the first tap that is actually used.
    const int start_x = origin_x + first_kx * cst.dilation_x;
    const int start_y = origin_y + first_ky * cst.dilation_y;

    const device ftype4 *base_in   = in + start_y * cst.input_width + start_x;
    const device ftype4x4 *base_wt = wt + out_z * cst.input_slice * cst.kernel_size
                                        + first_ky * cst.kernel_x + first_kx;
    device ftype4 *dst             = out + out_z * cst.output_size + out_y * cst.output_width + out_x;

    // Input elements to skip when moving one kernel row down.
    const int row_step = cst.input_width * cst.dilation_y;

    // Accumulate in float regardless of ftype, starting from the bias.
    float4 acc = float4(biasTerms[(short)gid.z]);
    for (int slice = 0; slice < cst.input_slice; ++slice) {
        const device ftype4x4 *slice_wt = base_wt + slice * cst.kernel_size;
        const device ftype4 *slice_in   = base_in + slice * cst.input_size;
        for (int ky = 0; ky < taps_y; ++ky) {
            const device ftype4x4 *row_wt = slice_wt + ky * cst.kernel_x;
            const device ftype4 *row_in   = slice_in + ky * row_step;
            for (int kx = 0; kx < taps_x; ++kx) {
                const ftype4x4 weight = row_wt[kx];
                const ftype4 pixel    = row_in[kx * cst.dilation_x];
                acc += float4(pixel * weight);
            }
        }
    }

    *dst = activate(ftype4(acc), cst.activation);
}
|
|
|
|
// Direct convolution that produces CONV_UNROLL (4) consecutive output slices per
// thread for a single (x, y) position, so each loaded input element is reused for
// up to four weight sets.
//   in        : input elements, addressed as slice * input_size + y * input_width + x
//   out       : output elements, addressed as z * output_size + y * output_width + x
//   cst       : convolution parameters
//   wt        : ftype4x4 weights; consecutive output slices are
//               input_slice * kernel_size elements apart
//   biasTerms : one ftype4 per output slice, added after accumulation
// Threads outside the output width/height exit without writing. The first of the
// four slices is always computed and stored; the other three are guarded against
// cst.output_slice.
kernel void conv_z4(const device ftype4 *in         [[buffer(0)]],
                    device ftype4 *out              [[buffer(1)]],
                    constant conv_constants& cst    [[buffer(2)]],
                    const device ftype4x4 *wt       [[buffer(3)]],
                    const device ftype4 *biasTerms  [[buffer(4)]],
                    uint3 gid                       [[thread_position_in_grid]]) {
    const int out_x = (int)gid.x;
    const int out_y = (int)gid.y;
    if (out_x >= cst.output_width || out_y >= cst.output_height) {
        return;
    }

    // Output slices handled by this thread, and whether slices 1..3 exist.
    // The indices are consecutive, so once one flag is false all later ones are too.
    const int4 slices = gid.z * CONV_UNROLL + int4(0, 1, 2, 3);
    const bool3 extra = slices.yzw < cst.output_slice;

    // Input coordinate hit by kernel tap (0, 0); negative inside the padding.
    const int origin_x = out_x * cst.stride_x - cst.pad_x;
    const int origin_y = out_y * cst.stride_y - cst.pad_y;

    // Restrict the tap range to taps that land inside the input image.
    const int first_kx = max(0, UP_DIV(-origin_x, cst.dilation_x));
    const int last_kx  = min(cst.kernel_x, UP_DIV(cst.input_width - origin_x, cst.dilation_x));
    const int first_ky = max(0, UP_DIV(-origin_y, cst.dilation_y));
    const int last_ky  = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));
    const short taps_x = last_kx - first_kx;
    const short taps_y = last_ky - first_ky;

    // Input coordinate of the first tap that is actually used.
    const int start_x = origin_x + first_kx * cst.dilation_x;
    const int start_y = origin_y + first_ky * cst.dilation_y;

    const device ftype4 *base_in   = in + start_y * cst.input_width + start_x;
    const device ftype4x4 *base_wt = wt + slices[0] * cst.input_slice * cst.kernel_size
                                        + first_ky * cst.kernel_x + first_kx;
    device ftype4 *dst             = out + slices[0] * cst.output_size
                                         + out_y * cst.output_width + out_x;

    // Distance in the weights between two consecutive output slices.
    const int group_stride = cst.input_slice * cst.kernel_size;
    // Input elements to skip when moving one kernel row down.
    const int row_step = cst.input_width * cst.dilation_y;

    float4 acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
    for (int slice = 0; slice < cst.input_slice; ++slice) {
        const device ftype4x4 *slice_wt = base_wt + slice * cst.kernel_size;
        const device ftype4 *slice_in   = base_in + slice * cst.input_size;
        for (int ky = 0; ky < taps_y; ++ky) {
            for (int kx = 0; kx < taps_x; ++kx) {
                const device ftype4x4 *tap = slice_wt + ky * cst.kernel_x + kx;
                const ftype4 pixel = slice_in[ky * row_step + kx * cst.dilation_x];

                acc0 += float4(pixel * tap[0]);
                if (extra[0]) { acc1 += float4(pixel * tap[group_stride]); }
                if (extra[1]) { acc2 += float4(pixel * tap[2 * group_stride]); }
                if (extra[2]) { acc3 += float4(pixel * tap[3 * group_stride]); }
            }
        }
    }

    // Bias is added in float, then the result is narrowed to ftype4 and activated.
    dst[0] = activate(ftype4(acc0 + float4(biasTerms[slices[0]])), cst.activation);
    if (extra[0]) {
        dst[cst.output_size] = activate(ftype4(acc1 + float4(biasTerms[slices[1]])), cst.activation);
    }
    if (extra[1]) {
        dst[2 * cst.output_size] = activate(ftype4(acc2 + float4(biasTerms[slices[2]])), cst.activation);
    }
    if (extra[2]) {
        dst[3 * cst.output_size] = activate(ftype4(acc3 + float4(biasTerms[slices[3]])), cst.activation);
    }
}
|
|
|
|
// Convolution that stages input patches in threadgroup memory before multiplying.
// Each thread produces CONV_UNROLL (4) horizontally adjacent outputs of output
// slice gid.z. The input slices are processed in steps of
// cst.threadgroup_input_slice; in every step the threads of a threadgroup split
// the staging work along z, synchronise, then each thread multiplies the staged
// patches with its own weights.
//   in          : input elements, addressed as slice * input_size + y * input_width + x
//   out         : output elements, addressed as z * output_size + y * output_width + x
//   cst         : convolution parameters
//   wt          : ftype4x4 weights, addressed as
//                 (z * input_slice + slice) * kernel_size + ky * kernel_x + kx
//   biasTerms   : one ftype4 per output slice, added after accumulation
//   cols        : threadgroup scratch; one ftype4x4 per (staged slice, ky, kx) whose
//                 four columns are the inputs for the four adjacent outputs
//   gid / tid   : position in grid / in threadgroup
//   thread_size : threads per threadgroup; only .z is used, to split the staging
// Threads with gid.z >= cst.output_slice still take part in staging and in every
// barrier, and only return afterwards.
// NOTE(review): `cols` is indexed without tid.x / tid.y, so this looks like it
// expects threadgroups that are one thread wide in x and y - confirm with the host
// dispatch code.
kernel void conv_local(const device ftype4 *in          [[buffer(0)]],
                       device ftype4 *out               [[buffer(1)]],
                       constant conv_constants& cst     [[buffer(2)]],
                       const device ftype4x4 *wt        [[buffer(3)]],
                       const device ftype4 *biasTerms   [[buffer(4)]],
                       threadgroup ftype4x4 *cols       [[threadgroup(0)]],
                       ushort3 gid                      [[thread_position_in_grid]],
                       ushort3 tid                      [[thread_position_in_threadgroup]],
                       ushort3 thread_size              [[threads_per_threadgroup]]) {
    // x of the first of the four outputs written by this thread.
    const short first_out_x = CONV_UNROLL * gid.x;

    // Input coordinate hit by kernel tap (0, 0) for that first output.
    const short origin_x = first_out_x * cst.stride_x - cst.pad_x;
    const short origin_y = gid.y * cst.stride_y - cst.pad_y;

    // Kernel rows that land inside the input image. Columns are checked per tap
    // because the four outputs have different valid ranges.
    const short ky_begin = max(0, UP_DIV(-origin_y, cst.dilation_y));
    const short ky_end   = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));

    // Weights belonging to this thread's output slice.
    const device ftype4x4 *slice_wt = wt + (int)gid.z * cst.input_slice * cst.kernel_size;

    // Column i accumulates output i of the four adjacent outputs.
    float4x4 acc = float4x4(0);

    const short step_count = UP_DIV(cst.input_slice, cst.threadgroup_input_slice);
    for (int step = 0; step < step_count; ++step) {
        // Input slices covered by this step: [chunk_begin, chunk_end).
        const int chunk_begin = step * cst.threadgroup_input_slice;
        const int chunk_end   = min(chunk_begin + cst.threadgroup_input_slice, cst.input_slice);
        const int chunk_len   = chunk_end - chunk_begin;

        // ---- im2col: this thread stages slices [my_begin, my_end) of the chunk ----
        const int per_thread = UP_DIV(chunk_len, (int)thread_size.z);
        const int my_begin   = tid.z * per_thread;
        const int my_end     = min(my_begin + per_thread, chunk_len);

        for (int z = my_begin; z < my_end; ++z) {
            for (int ky = ky_begin; ky < ky_end; ++ky) {
                // Start of the input row read by kernel row ky in slice (chunk_begin + z).
                const device ftype4 *row_in = in
                    + (z + chunk_begin) * cst.input_size
                    + (origin_y + ky * cst.dilation_y) * cst.input_width;

                for (int kx = 0; kx < cst.kernel_x; ++kx) {
                    // Input x for each of the four outputs at this tap.
                    const int4 x4 = origin_x + kx * cst.dilation_x + cst.stride_x * int4(0, 1, 2, 3);
                    const bool4 inside = 0 <= x4 && x4 < cst.input_width;

                    // Out-of-range columns are staged as zero.
                    cols[z * cst.kernel_size + ky * cst.kernel_x + kx] = {
                        inside[0] ? row_in[x4[0]] : 0,
                        inside[1] ? row_in[x4[1]] : 0,
                        inside[2] ? row_in[x4[2]] : 0,
                        inside[3] ? row_in[x4[3]] : 0
                    };
                }
            }
        }
        // Every staged patch must be visible before any thread reads `cols`.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ---- gemm: multiply the staged patches with this thread's weights ----
        if ((short)gid.z < cst.output_slice) {
            for (int z = 0; z < chunk_len; ++z) {
                for (int ky = ky_begin; ky < ky_end; ++ky) {
                    for (int kx = 0; kx < cst.kernel_x; ++kx) {
                        const int tap         = ky * cst.kernel_x + kx;
                        const ftype4x4 patch  = cols[z * cst.kernel_size + tap];
                        const ftype4x4 weight = slice_wt[(z + chunk_begin) * cst.kernel_size + tap];

                        acc[0] += float4(patch[0] * weight);
                        acc[1] += float4(patch[1] * weight);
                        acc[2] += float4(patch[2] * weight);
                        acc[3] += float4(patch[3] * weight);
                    }
                }
            }
        }

        // `cols` is overwritten by the next step, so wait until every thread has
        // finished reading it. Not needed after the final step.
        if (step + 1 < step_count) {
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    } // end step

    // ---- save ----
    if ((short)gid.z >= cst.output_slice) {
        return;
    }

    const float4 bias = float4(biasTerms[(short)gid.z]);
    device ftype4 *dst = out + (int)gid.z * cst.output_size
                             + (int)gid.y * cst.output_width + first_out_x;

    // The first output is stored unconditionally; the remaining three only when
    // they still fall inside the output row.
    dst[0] = activate(ftype4(acc[0] + bias), cst.activation);
    for (int i = 1; i < CONV_UNROLL; ++i) {
        if (first_out_x + i < cst.output_width) {
            dst[i] = activate(ftype4(acc[i] + bias), cst.activation);
        }
    }
}
|