mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			216 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Metal
		
	
	
	
		
		
			
		
	
	
			216 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Metal
		
	
	
	
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								//  MetalConvolution.metal
							 | 
						||
| 
								 | 
							
								//  MNN
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								//  Created by MNN on 2018/08/22.
							 | 
						||
| 
								 | 
							
								//  Copyright © 2018, Alibaba Group Holding Limited
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#include <metal_stdlib>
							 | 
						||
| 
								 | 
							
								#include "MetalConvolutionActivation.metal"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								using namespace metal;
							 | 
						||
| 
								 | 
							
								using namespace MNN;
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#define CONV_UNROLL (4)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Quantizes one ftype4 element per thread into a char4.
								//   in    : source values, read at index gid
								//   out   : destination values, written at index gid
								//   scale : multiplier applied to the source before rounding
								//   range : clamp bounds (range.x = lower, range.y = upper)
								// No bound check is performed on gid.
								// NOTE(review): presumably the host dispatches exactly one thread per element — confirm.
								kernel void conv_quantize(const device ftype4 *in   [[buffer(0)]],
								                          device char4 *out         [[buffer(1)]],
								                          constant float& scale     [[buffer(2)]],
								                          constant int2& range      [[buffer(3)]],
								                          uint gid                  [[thread_position_in_grid]]) {
								    const int index = int(gid);
								
								    // Go through int4 before narrowing to char4: the original author notes that
								    // the path ftype4 -> int4 -> char4 is right and ftype4 -> char4 is wrong.
								    const float4 scaled  = float4(in[index]) * scale;
								    const int4   rounded = int4(round(scaled));
								
								    out[index] = char4(clamp(rounded, range.x, range.y));
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Parameters bound at buffer(2) of the conv, conv_z4 and conv_local kernels.
								// NOTE(review): presumably mirrored by a host-side struct, so field order and
								// types must stay exactly as they are — confirm against the host code.
								struct conv_constants {
								    // Input: width is the row stride and the x bound, height is the y bound,
								    // size is the stride between slices, slice is the number of slices iterated.
								    // NOTE(review): presumably input_size == input_width * input_height — confirm.
								    int input_width, input_height, input_size, input_slice;
								
								    // Output: width/height bound the written positions, size is the stride
								    // between output slices, slice bounds the output slice index.
								    int output_width, output_height, output_size, output_slice;
								
								    // Number of input slices conv_local stages into threadgroup memory per step.
								    int threadgroup_input_slice;
								
								    // Kernel window: x/y extents; size is the weight stride between input slices.
								    int kernel_x, kernel_y, kernel_size;
								
								    // Window placement: output-to-input step, leading padding, tap spacing.
								    int stride_x, stride_y;
								    int pad_x, pad_y;
								    int dilation_x, dilation_y;
								
								    // Forwarded to activate() when a result is stored.
								    conv_activation_type activation;
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Direct convolution: each thread produces the single ftype4 output element
								// at column gid.x, row gid.y of output slice gid.z.
								//   in        : input, indexed as slice * input_size + y * input_width + x
								//   out       : output, indexed as slice * output_size + y * output_width + x
								//   cst       : geometry and activation parameters
								//   wt        : weights, one ftype4x4 per (output slice, input slice, ky, kx)
								//   biasTerms : one ftype4 per output slice, used as the accumulator's start value
								// Threads outside output_width x output_height exit immediately; gid.z is not checked.
								kernel void conv(const device ftype4 *in        [[buffer(0)]],
								                 device ftype4 *out             [[buffer(1)]],
								                 constant conv_constants& cst   [[buffer(2)]],
								                 const device ftype4x4 *wt      [[buffer(3)]],
								                 const device ftype4 *biasTerms [[buffer(4)]],
								                 uint3 gid                      [[thread_position_in_grid]]) {
								    const int out_x = (int)gid.x;
								    const int out_y = (int)gid.y;
								    const int out_z = (int)gid.z;
								    if (out_x >= cst.output_width || out_y >= cst.output_height) return;
								
								    // Input coordinate of the window's first tap; negative inside the padding.
								    int origin_x = out_x * cst.stride_x - cst.pad_x;
								    int origin_y = out_y * cst.stride_y - cst.pad_y;
								
								    // Restrict the kernel taps to the range [first, last) that lands inside the input.
								    const int first_kx = max(0, (UP_DIV(-origin_x, cst.dilation_x)));
								    const int last_kx  = min(cst.kernel_x, UP_DIV(cst.input_width - origin_x, cst.dilation_x));
								    const int first_ky = max(0, (UP_DIV(-origin_y, cst.dilation_y)));
								    const int last_ky  = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));
								    const short taps_x = last_kx - first_kx;
								    const short taps_y = last_ky - first_ky;
								
								    // Move the origin onto the first tap that is actually used.
								    origin_x += first_kx * cst.dilation_x;
								    origin_y += first_ky * cst.dilation_y;
								
								    const device ftype4 *src       = in + origin_y * cst.input_width + origin_x;
								    const device ftype4x4 *weights = wt + out_z * cst.input_slice * cst.kernel_size + first_ky * cst.kernel_x + first_kx;
								    device ftype4 *dst             = out + out_z * cst.output_size + out_y * cst.output_width + out_x;
								
								    // Distance in the input between two consecutive kernel rows.
								    const int row_stride = cst.input_width * cst.dilation_y;
								
								    float4 acc = float4(biasTerms[(short)gid.z]);
								    for (int slice = 0; slice < cst.input_slice; slice++) {
								        for (int ky = 0; ky < taps_y; ky++) {
								            const int wt_row  = slice * cst.kernel_size + ky * cst.kernel_x;
								            const int src_row = slice * cst.input_size  + ky * row_stride;
								            for (int kx = 0; kx < taps_x; kx++) {
								                const auto w = weights[wt_row + kx];
								                const auto v = src[src_row + kx * cst.dilation_x];
								                acc += float4(v * w);
								            }
								        }
								    }
								    *dst = activate(ftype4(acc), cst.activation);
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Direct convolution producing up to CONV_UNROLL consecutive output slices per
								// thread, starting at slice gid.z * CONV_UNROLL, for position (gid.x, gid.y).
								// Each loaded input value is reused for every slice the thread computes.
								//   in        : input, indexed as slice * input_size + y * input_width + x
								//   out       : output, indexed as slice * output_size + y * output_width + x
								//   cst       : geometry and activation parameters
								//   wt        : weights, one ftype4x4 per (output slice, input slice, ky, kx)
								//   biasTerms : one ftype4 per output slice, added before activation
								// The first slice is always computed and stored; the following three are
								// skipped when their index is not below cst.output_slice.
								kernel void conv_z4(const device ftype4 *in         [[buffer(0)]],
								                    device ftype4 *out              [[buffer(1)]],
								                    constant conv_constants& cst    [[buffer(2)]],
								                    const device ftype4x4 *wt       [[buffer(3)]],
								                    const device ftype4 *biasTerms  [[buffer(4)]],
								                    uint3 gid                       [[thread_position_in_grid]]) {
								    const int out_x = (int)gid.x;
								    const int out_y = (int)gid.y;
								    if (out_x >= cst.output_width || out_y >= cst.output_height) return;
								
								    // Output slice indices handled by this thread, and whether slices 1..3 exist.
								    int4 slices = gid.z * CONV_UNROLL + int4(0, 1, 2, 3);
								    bool3 has_slice = slices.yzw < cst.output_slice;
								
								    // Input coordinate of the window's first tap; negative inside the padding.
								    int origin_x = out_x * cst.stride_x - cst.pad_x;
								    int origin_y = out_y * cst.stride_y - cst.pad_y;
								
								    // Restrict the kernel taps to the range [first, last) that lands inside the input.
								    const int first_kx = max(0, (UP_DIV(-origin_x, cst.dilation_x)));
								    const int last_kx  = min(cst.kernel_x, UP_DIV(cst.input_width - origin_x, cst.dilation_x));
								    const int first_ky = max(0, (UP_DIV(-origin_y, cst.dilation_y)));
								    const int last_ky  = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));
								    const short taps_x = last_kx - first_kx;
								    const short taps_y = last_ky - first_ky;
								
								    // Move the origin onto the first tap that is actually used.
								    origin_x += first_kx * cst.dilation_x;
								    origin_y += first_ky * cst.dilation_y;
								
								    const device ftype4 *src       = in + origin_y * cst.input_width + origin_x;
								    const device ftype4x4 *weights = wt + slices[0] * cst.input_slice * cst.kernel_size + first_ky * cst.kernel_x + first_kx;
								    device ftype4 *dst             = out + slices[0] * cst.output_size + out_y * cst.output_width + out_x;
								
								    // Weight distance between two consecutive output slices.
								    const int wt_slice_stride = cst.input_slice * cst.kernel_size;
								    // Distance in the input between two consecutive kernel rows.
								    const int row_stride = cst.input_width * cst.dilation_y;
								
								    float4 acc0 = 0;
								    float4 acc1 = 0;
								    float4 acc2 = 0;
								    float4 acc3 = 0;
								    for (int slice = 0; slice < cst.input_slice; slice++) {
								        for (int ky = 0; ky < taps_y; ky++) {
								            for (int kx = 0; kx < taps_x; kx++) {
								                const device ftype4x4 *w = weights + ky * cst.kernel_x + kx;
								                const auto v = src[ky * row_stride + kx * cst.dilation_x];
								
								                acc0 += float4(v * *w);
								                if (has_slice[0]) {
								                    w += wt_slice_stride;
								                    acc1 += float4(v * *w);
								                }
								                if (has_slice[1]) {
								                    w += wt_slice_stride;
								                    acc2 += float4(v * *w);
								                }
								                if (has_slice[2]) {
								                    w += wt_slice_stride;
								                    acc3 += float4(v * *w);
								                }
								            }
								        }
								        // Advance both cursors to the next input slice.
								        weights += cst.kernel_size;
								        src     += cst.input_size;
								    }
								
								    // Store: add the slice's bias, activate, and step to the next output slice.
								    *dst = activate(ftype4(acc0 + float4(biasTerms[slices[0]])), cst.activation);
								    if (has_slice[0]) {
								        dst += cst.output_size;
								        *dst = activate(ftype4(acc1 + float4(biasTerms[slices[1]])), cst.activation);
								    }
								    if (has_slice[1]) {
								        dst += cst.output_size;
								        *dst = activate(ftype4(acc2 + float4(biasTerms[slices[2]])), cst.activation);
								    }
								    if (has_slice[2]) {
								        dst += cst.output_size;
								        *dst = activate(ftype4(acc3 + float4(biasTerms[slices[3]])), cst.activation);
								    }
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Convolution that stages input patches in threadgroup memory (im2col) and then
								// multiplies them with the weights (gemm). Each thread produces up to CONV_UNROLL
								// horizontally adjacent outputs of row gid.y in output slice gid.z, starting at
								// column CONV_UNROLL * gid.x.
								//   in        : input, indexed as slice * input_size + y * input_width + x
								//   out       : output, indexed as slice * output_size + y * output_width + x
								//   cst       : geometry and activation parameters
								//   wt        : weights, one ftype4x4 per (output slice, input slice, ky, kx)
								//   biasTerms : one ftype4 per output slice, added before activation
								//   cols      : threadgroup scratch; entry (z, ky, kx) packs the input values seen
								//               by the four unrolled output columns for that kernel tap
								// The input slices are processed in steps of cst.threadgroup_input_slice. In each
								// step the threads split the im2col fill between them by tid.z.
								// NOTE(review): the patch written to cols depends on gid.x / gid.y, so this
								// presumably requires a threadgroup that spans only z — confirm against the host.
								// NOTE(review): gid.y and the first output column are not bound-checked here — confirm
								// the host dispatch never overshoots.
								kernel void conv_local(const device ftype4 *in          [[buffer(0)]],
								                       device ftype4 *out               [[buffer(1)]],
								                       constant conv_constants& cst     [[buffer(2)]],
								                       const device ftype4x4 *wt        [[buffer(3)]],
								                       const device ftype4 *biasTerms   [[buffer(4)]],
								                       threadgroup ftype4x4 *cols       [[threadgroup(0)]],
								                       ushort3 gid                      [[thread_position_in_grid]],
								                       ushort3 tid                      [[thread_position_in_threadgroup]],
								                       ushort3 thread_size              [[threads_per_threadgroup]]) {
								    // First output column of this thread and the matching input window origin.
								    short first_out_x = CONV_UNROLL * gid.x;
								    short origin_x    = first_out_x * cst.stride_x - cst.pad_x;
								    short origin_y    = gid.y * cst.stride_y - cst.pad_y;
								
								    // Kernel rows [first_ky, last_ky) are the ones that land inside the input.
								    short first_ky = max(0, UP_DIV(-origin_y, cst.dilation_y));
								    short last_ky  = min(cst.kernel_y, UP_DIV(cst.input_height - origin_y, cst.dilation_y));
								
								    // Weights of this thread's output slice.
								    const device ftype4x4 *slice_wt = wt + (int)gid.z * cst.input_slice * cst.kernel_size;
								
								    // One accumulator column per unrolled output position.
								    float4x4 acc = float4x4(0);
								    short step_count = UP_DIV(cst.input_slice, cst.threadgroup_input_slice);
								    for (int step = 0; step < step_count; step++) {
								        // Input slices [slice_begin, slice_end) are handled in this step.
								        const int slice_begin = step * cst.threadgroup_input_slice;
								        const int slice_end   = min(slice_begin + cst.threadgroup_input_slice, cst.input_slice);
								        const int slice_count = slice_end - slice_begin;
								
								        // im2col: this thread fills the slices [fill_begin, fill_end) of cols.
								        const int fill_share = UP_DIV(slice_count, (int)thread_size.z);
								        const int fill_begin = tid.z * fill_share;
								        const int fill_end   = min(fill_begin + fill_share, slice_count);
								
								        for (int z = fill_begin; z < fill_end; z++) {
								            for (short ky = first_ky; ky < last_ky; ky++) {
								                // Start of the input row read by this kernel row.
								                const device ftype4 *src_row = in
								                    + (z + slice_begin) * cst.input_size
								                    + (origin_y + ky * cst.dilation_y) * cst.input_width;
								                const int col_row = z * cst.kernel_size + ky * cst.kernel_x;
								
								                for (int kx = 0; kx < cst.kernel_x; kx++) {
								                    // Input columns read by the four unrolled outputs for this tap.
								                    int4 src_x = origin_x + kx * cst.dilation_x + cst.stride_x * int4(0, 1, 2, 3);
								                    bool4 inside = 0 <= src_x && src_x < cst.input_width;
								                    // Columns outside the input contribute zero.
								                    cols[col_row + kx] = {
								                        inside[0] ? src_row[src_x[0]] : 0,
								                        inside[1] ? src_row[src_x[1]] : 0,
								                        inside[2] ? src_row[src_x[2]] : 0,
								                        inside[3] ? src_row[src_x[3]] : 0
								                    };
								                }
								            }
								        }
								        // Wait until every thread has finished writing cols.
								        threadgroup_barrier(mem_flags::mem_threadgroup);
								
								        // gemm: threads whose output slice does not exist skip the math but
								        // still take part in the barriers.
								        if ((short)gid.z < cst.output_slice) {
								            for (int z = 0; z < slice_count; z++) {
								                for (short ky = first_ky; ky < last_ky; ky++) {
								                    const int col_row = z * cst.kernel_size + ky * cst.kernel_x;
								                    const int wt_row  = (z + slice_begin) * cst.kernel_size + ky * cst.kernel_x;
								                    for (int kx = 0; kx < cst.kernel_x; kx++) {
								                        const auto patch = cols[col_row + kx];
								                        const auto w     = slice_wt[wt_row + kx];
								                        acc += {
								                            float4(patch[0] * w),
								                            float4(patch[1] * w),
								                            float4(patch[2] * w),
								                            float4(patch[3] * w)
								                        };
								                    }
								                }
								            }
								        }
								
								        // cols is overwritten by the next step, so wait for all readers first;
								        // nothing follows the last step, hence no barrier there.
								        if (step != step_count - 1) {
								            threadgroup_barrier(mem_flags::mem_threadgroup);
								        }
								    } // end step
								
								    // save
								    if ((short)gid.z >= cst.output_slice) return;
								
								    float4 bias = float4(biasTerms[(short)gid.z]);
								    device ftype4 *dst = out + (int)gid.z * cst.output_size + (int)gid.y * cst.output_width + first_out_x;
								    // Columns 1..3 are stored only while they stay inside the output row.
								    bool3 in_row = (first_out_x + int3(1, 2, 3)) < cst.output_width;
								
								    dst[0] = activate((ftype4)(acc[0] + bias), cst.activation);
								    if (in_row[0]) {
								        dst[1] = activate((ftype4)(acc[1] + bias), cst.activation);
								    }
								    if (in_row[1]) {
								        dst[2] = activate((ftype4)(acc[2] + bias), cst.activation);
								    }
								    if (in_row[2]) {
								        dst[3] = activate((ftype4)(acc[3] + bias), cst.activation);
								    }
								}
							 |