# MNN/tools/script/opencl_kernel_check.py
#
# Viewer metadata: 206 lines, 12 KiB, Python.
#
# Viewer notice: this file contains Unicode characters that might be confused
# with other characters. If that is intentional, the warning can be ignored.

import sys
import os
import re
import itertools
def run_cmd(args):
    """Run a command and return everything it printed as text.

    Args:
        args: Argument list handed to ``subprocess.Popen`` (program first).

    Returns:
        The process output as ``str``. stderr is redirected into stdout, so
        both streams are captured together. The bytes are decoded as UTF-8;
        bytes that are not valid UTF-8 are replaced with U+FFFD rather than
        raising ``UnicodeDecodeError``, so arbitrary tool output cannot abort
        the caller.
    """
    from subprocess import PIPE, STDOUT, Popen

    # The context manager closes the pipe and reaps the child even if
    # communicate() is interrupted.
    with Popen(args, stdout=PIPE, stderr=STDOUT) as process:
        output, _ = process.communicate()
    return output.decode('utf-8', errors='replace')
def extract_macros(file_content):
    """Extract the macros that control conditional compilation of a kernel.

    The source text is scanned for preprocessor conditionals and the macro
    names they test are sorted into two groups:

    * value macros (``macros_num``): names tested by ``#if NAME`` or
      ``#elif NAME`` get the candidate values ``{1, 2, 3, 4, 8}``; names
      containing ``LOCAL_SIZE`` that are tested by ``#ifdef``/``#ifndef`` get
      ``{1, 2, 3, 4, 16}``.
    * on/off macros (``macros``): every other name tested by ``#ifdef``,
      ``#ifndef`` or the ``defined`` operator (both ``defined NAME`` and
      ``defined(NAME)``), mapped to ``None``.

    Any macro the file ``#define``s itself is dropped from both groups, and
    ``MNN_SUPPORT_FP16`` is always dropped from the on/off group.

    Args:
        file_content: Full text of the kernel source file.

    Returns:
        A two-element list ``[macros_num, macros]`` of dictionaries keyed by
        macro name, as described above.
    """
    macros = {}
    macros_num = {}

    ifdef_pattern = re.compile(r'#ifdef\s+(\w+)')
    ifndef_pattern = re.compile(r'#ifndef\s+(\w+)')
    # Require an identifier so that "#if 0" / "#elif 1" do not register the
    # literal as a macro name.
    if_pattern = re.compile(r'#if\s+([A-Za-z_]\w*)')
    elif_pattern = re.compile(r'#elif\s+([A-Za-z_]\w*)')
    # Accept both "defined NAME" and "defined(NAME)"; the word boundary keeps
    # words such as "undefined" from matching.
    defined_pattern = re.compile(r'\bdefined(?:\s*\(\s*|\s+)([A-Za-z_]\w*)')
    define_pattern = re.compile(r'#define\s+(\w+)')

    # #ifdef first, then #ifndef: the insertion order of the dictionaries is
    # the order in which callers enumerate the macros.
    for pattern in (ifdef_pattern, ifndef_pattern):
        for macro_name in pattern.findall(file_content):
            if "LOCAL_SIZE" in macro_name:
                macros_num[macro_name] = {1, 2, 3, 4, 16}
            else:
                macros[macro_name] = None

    for pattern in (if_pattern, elif_pattern):
        for macro_name in pattern.findall(file_content):
            # "#if defined ..." is covered by defined_pattern below.
            if macro_name != "defined":
                macros_num[macro_name] = {1, 2, 3, 4, 8}

    for macro_name in defined_pattern.findall(file_content):
        macros[macro_name] = None

    # Macros the kernel defines itself must not be overridden from outside.
    for macro_name in define_pattern.findall(file_content):
        macros.pop(macro_name, None)
        macros_num.pop(macro_name, None)

    macros.pop("MNN_SUPPORT_FP16", None)
    return [macros_num, macros]
# Values tried for a value macro when no candidate set is supplied for it.
_DEFAULT_MACRO_VALUES = (1, 2, 3, 4, 8)


def _build_compile_options(macros_all, operator_macro, extra_macro, filename):
    """Return the list of compiler option strings to try for one kernel."""
    macros_num = macros_all[0]
    macros = macros_all[1]
    float_option = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT=convert_float -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT3=convert_float3 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT=convert_float -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT3=convert_float3 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16"
    float_option += " -DINPUT_TYPE_I=float -DINPUT_TYPE_I4=float4 -DINPUT_TYPE=float -DINPUT_TYPE4=float4 -DINPUT_TYPE16=float16 -DRI_DATA=read_imagef -DOUTPUT_TYPE_I=float -DOUTPUT_TYPE_I4=float4 -DCONVERT_OUTPUT_I4=convert_float4 -DOUTPUT_TYPE=float -DOUTPUT_TYPE4=float4 -DOUTPUT_TYPE16=float16 -DCONVERT_OUTPUT4=convert_float4 -DCONVERT_OUTPUT16=convert_float16 -DWI_DATA=write_imagef"
    if filename in extra_macro:
        float_option += extra_macro[filename]

    # On/off macros: one option string per combination of defined/undefined.
    keys = list(macros)
    options_normal = []
    for combination in itertools.product([0, 1], repeat=len(keys)):
        option_normal = float_option
        for macro_name, macro_value in zip(keys, combination):
            if macro_value == 1:
                option_normal += f" -D{macro_name}={macro_value} "
        options_normal.append(option_normal)

    # Value macros: each macro is varied on its own (the other value macros
    # stay undefined on that line) over the candidate values recorded for it.
    if len(macros_num) > 0:
        options_num_normal = []
        for macro_name in macros_num:
            candidates = None
            if isinstance(macros_num, dict):
                candidates = macros_num[macro_name]
            # sorted() makes the generated file independent of set ordering.
            for value in sorted(candidates or _DEFAULT_MACRO_VALUES):
                option_num = f" -D{macro_name}={value} "
                for option_normal in options_normal:
                    options_num_normal.append(option_normal + option_num)
    else:
        options_num_normal = options_normal

    # OPERATOR macro: only the first operator is combined with every option
    # string built so far; each remaining operator is checked against the
    # first option string only, which keeps the number of builds manageable.
    if len(operator_macro) > 0:
        options = []
        for index, op in enumerate(operator_macro):
            option_operator = f" -DOPERATOR={op} "
            if index == 0:
                for option_num_normal in options_num_normal:
                    options.append(option_num_normal + option_operator)
            else:
                options.append(options_num_normal[0] + option_operator)
    else:
        options = options_num_normal
    return options


def compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android):
    """Try to build one kernel under many macro configurations.

    Every configuration is written as one line of ``option.txt`` in the
    current directory, then the ``OpenCLProgramBuildTest`` program is run with
    ``filename`` as its argument and its output is printed.
    NOTE(review): the build-test program presumably reads ``option.txt`` and
    ``kernel.cl`` itself; that is not visible from this script.

    Args:
        macros_all: ``[macros_num, macros]`` as returned by
            ``extract_macros``. ``macros`` holds on/off macros, all
            combinations of which are generated. ``macros_num`` maps value
            macros to their candidate values; an empty or missing candidate
            set falls back to ``(1, 2, 3, 4, 8)``.
        operator_macro: Iterable of expressions to pass as ``-DOPERATOR=...``;
            may be empty.
        extra_macro: Mapping from kernel file name to an extra option string
            appended for that file.
        filename: Kernel file name, used for the ``extra_macro`` lookup and
            passed to the build-test program.
        test_for_android: ``1`` to push the files with ``adb`` and run the
            test on a device; any other value runs the local executable.
    """
    options = _build_compile_options(macros_all, operator_macro, extra_macro, filename)
    with open('option.txt', 'w') as outfile:
        outfile.writelines(option + '\n' for option in options)

    if test_for_android == 1:
        for artifact in ('kernel.cl', 'option.txt', 'OpenCLProgramBuildTest.out'):
            run_cmd(['adb', 'push', artifact, '/data/local/tmp/MNN'])
        res = run_cmd(['adb', 'shell', 'cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH && ./OpenCLProgramBuildTest.out %s'%(filename)])
    elif sys.platform.startswith('win'):
        res = run_cmd(['OpenCLProgramBuildTest.exe', f'{filename}'])
    else:
        res = run_cmd(['./OpenCLProgramBuildTest.out', f'{filename}'])
    print(res)
def main():
    """Check every ``.cl`` kernel in a directory under many macro settings.

    Command line: ``opencl_kernel_check.py [path] [without_subgroup]
    [test_for_android]``

    * ``path``: directory containing the ``.cl`` files (default ``.``).
    * ``without_subgroup``: ``1`` (default) skips files whose name contains
      ``subgroup``; any other integer includes them.
    * ``test_for_android``: ``1`` runs the build test on a device through
      ``adb``; default ``0`` runs it locally.

    Each kernel is copied to ``kernel.cl`` in the current directory, its
    macros are extracted, and ``compile_with_macros`` is called for it.
    """
    print("opencl_kernel_check.py path without_subgroup test_for_android")
    path = '.'
    without_subgroup = 1
    test_for_android = 0
    if len(sys.argv) > 1:
        path = sys.argv[1]
    if len(sys.argv) > 2:
        without_subgroup = int(sys.argv[2])
    if len(sys.argv) > 3:
        test_for_android = int(sys.argv[3])

    def unique(items):
        # Drop duplicates but keep the written order, so the operator that
        # gets the full combination sweep is the same on every run.
        return list(dict.fromkeys(items))

    # The fourth entry is the element-wise minimum. It used to repeat the
    # maximum expression, which left the minimum untested for these kernels.
    binaryvec_operator = unique(["in0+in1", "in0*in1", "in0-in1", "in0>in1?in1:in0", "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))",
        "in0>in1?in0:in1", "convert_float4(-isgreater(in0,in1))", "convert_float4(-isless(in0,in1))", "convert_float4(-islessequal(in0,in1))", "convert_float4(-isgreaterequal(in0,in1))", "convert_float4(-isequal(in0,in1))",
        "floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1",
        "pow(in0,in1)", "(in0-in1)*(in0-in1)", "(in1==(float)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float)PI)))", "convert_float4(-isnotequal(in0,in1))",
        "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1"])
    binary_operator = unique(["in0*in1", "in0+in1", "in0-in1", "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))", "in0>in1?in1:in0", "in0>in1?in0:in1", "(float)(isgreater(in0,in1))",
        "(float)(isless(in0,in1))", "(float)(islessequal(in0,in1))", "(float)(isgreaterequal(in0,in1))", "(float)(isequal(in0,in1))", "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))",
        "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", "pow(in0,in1)", "(in0-in1)*(in0-in1)",
        "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))", "(float)(isnotequal(in0,in1))", "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1"])
    unary_operator = unique(["fabs(convert_float4(in))", "in*in", "rsqrt(convert_float4(in)>(float4)(0.000001)?convert_float4(in):(float4)(0.000001))", "-(in)", "exp(convert_float4(in))", "cos(convert_float4(in))", "sin(convert_float4(in))",
        "tan(convert_float4(in))", "atan(convert_float4(in))", "sqrt(convert_float4(in))", "ceil(convert_float4(in))", "native_recip(convert_float4(in))", "log1p(convert_float4(in))", "native_log(convert_float4(in)>(float4)(0.0000001)?convert_float4(in):(float4)(0.0000001))",
        "floor(convert_float4(in))", "in>(float4)((float)0)?(in+native_log(exp(convert_float4(-(in)))+(float4)(1.0))):(native_log(exp(convert_float4(in))+(float4)(1.0)))", "acosh(convert_float4(in))", "sinh(convert_float4(in))", "asinh(convert_float4(in))",
        "atanh(convert_float4(in))", "sign(convert_float4(in))", "round(convert_float4(in))", "cosh(convert_float4(in))", "erf(convert_float4(in))", "erfc(convert_float4(in))", "expm1(convert_float4(in))", "native_recip((float4)1+native_exp(convert_float4(-in)))",
        "(convert_float4(in)*native_recip((float4)1+native_exp(convert_float4(-in))))", "tanh(convert_float4(in))", "convert_float4(in)>(float4)(-3.0f)?(convert_float4(in)<(float4)(3.0f)?((convert_float4(in)*(convert_float4(in)+(float4)3.0f))/(float4)6.0f):convert_float4(in)):(float4)(0.0f)",
        "gelu(convert_float4(in))", "(erf(convert_float4(in)*(float4)0.7071067932881648)+(float4)1.0)*convert_float4(in)*(float4)0.5", "native_recip((float4)(1.0)+native_exp(convert_float4(-(in))))",
        "tanh(convert_float4(in))"])

    # Extra options for kernels that need macros the scan cannot supply.
    intel_subgroup_macros = " -DINTEL_DATA=uint -DAS_INPUT_DATA=as_float -DAS_INPUT_DATA4=as_float4 -DAS_OUTPUT_DATA4=as_uint4 -DINTEL_SUB_GROUP_READ=intel_sub_group_block_read -DINTEL_SUB_GROUP_READ4=intel_sub_group_block_read4 -DINTEL_SUB_GROUP_WRITE4=intel_sub_group_block_write4"
    conv_subgroup_macros = " -DINPUT_LINE_SIZE=16 -DINPUT_BLOCK_SIZE=16 -DINPUT_CHANNEL=16 -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1"
    reduction_macros = " -DOPERATE(a,b)=(a+b) -DVALUE=0"
    extra_macro = {
        "binary_subgroup_buf.cl": intel_subgroup_macros,
        "conv_2d_c1_subgroup_buf.cl": conv_subgroup_macros,
        "conv_2d_c16_subgroup_buf.cl": conv_subgroup_macros,
        "depthwise_conv2d_subgroup_buf.cl": " -DFILTER_HEIGHT=3 -DFILTER_WIDTH=3 -DDILATION_HEIGHT=1 -DDILATION_WIDTH=1 -DSTRIDE_HEIGHT=1 -DSTRIDE_WIDTH=1",
        "matmul_local_buf.cl": " -DOPWM=64 -DOPWN=128 -DCPWK=8 -DOPTM=4 -DOPTN=8",
        "pooling_subgroup_buf.cl": " -DINPUT_LINE_SIZE=16 -DSTRIDE_Y=2 -DSTRIDE_X=2 -DKERNEL_Y=4 -DKERNEL_X=4",
        "reduction_buf.cl": reduction_macros,
        "reduction.cl": reduction_macros,
        "unary_subgroup_buf.cl": intel_subgroup_macros,
    }

    binaryvec_files = ("binary_buf.cl", "binary.cl", "loop.cl", "binary_subgroup_buf.cl")
    unary_files = ("unary_buf.cl", "unary.cl", "unary_subgroup_buf.cl")
    scratch_kernel = os.path.abspath('kernel.cl')

    # Walk every .cl file in the directory, in a stable order.
    for filename in sorted(os.listdir(path)):
        if not filename.endswith('.cl'):
            continue
        # Decide on skipping before touching the file system.
        if "subgroup" in filename and without_subgroup == 1:
            continue
        source_file = os.path.join(path, filename)
        # With path='.', the scratch copy left by an earlier run sits next to
        # the real kernels; it is not a source file.
        if os.path.abspath(source_file) == scratch_kernel:
            continue

        with open(source_file, 'r') as file:
            file_content = file.read()
        with open('kernel.cl', 'w') as outfile:
            outfile.write(file_content)

        # Extract the macros that steer conditional compilation.
        macros_all = extract_macros(file_content)

        # Pick the OPERATOR expressions that apply to this kernel, if any.
        if filename in binaryvec_files:
            operator_macro = binaryvec_operator
        elif filename == "loop_buf.cl":
            operator_macro = binary_operator
        elif filename in unary_files:
            operator_macro = unary_operator
        else:
            operator_macro = []

        # Compile with the different macro values.
        compile_with_macros(macros_all, operator_macro, extra_macro, filename, test_for_android)
# Run the checker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()