mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			199 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			199 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| # -*- coding: UTF-8 -*-
 | |
| import os
 | |
| import sys
 | |
| import MNN
 | |
| import numpy as np
 | |
| 
 | |
| total_num = 0
 | |
| wrongs = []
 | |
| 
 | |
| def parseConfig(root_dir):
 | |
|     configName = os.path.join(root_dir, 'config.txt')
 | |
|     if not os.path.exists(configName):
 | |
|         return False
 | |
|     try:
 | |
|         config = open(configName, 'rt', encoding='utf-8')
 | |
|     except:
 | |
|         import io
 | |
|         config = io.open(configName, 'rt', encoding='utf-8')
 | |
|     res = {}
 | |
|     res['model_name'] = os.path.join(root_dir, 'temp.bin')
 | |
|     for line in config.readlines():
 | |
|         if line[0] == '#':
 | |
|             continue
 | |
|         value = line[line.find(' = ') + 3:].strip()
 | |
|         if 'input_size' in line:
 | |
|             res['input_size'] = int(value)
 | |
|         elif 'input_names' in line:
 | |
|             input_names = value.split(',')
 | |
|             res['input_names'] = input_names
 | |
|             res['given_names'] = [ os.path.join(root_dir, x + '.txt') for x in input_names]
 | |
|         elif 'input_dims' in line:
 | |
|             res['input_dims'] = []
 | |
|             for val in value.split(','):
 | |
|                 res['input_dims'].append([int(x) for x in val.split('x')])
 | |
|         elif 'output_size' in line:
 | |
|             res['output_size'] = int(value)
 | |
|         elif 'output_names' in line:
 | |
|             output_names = value.split(',')
 | |
|             res['output_names'] = output_names
 | |
|             res['expect_names'] = []
 | |
|             for i in range(len(output_names)):
 | |
|                 expect_name = os.path.join(root_dir, output_names[i] + '.txt')
 | |
|                 if os.path.exists(expect_name):
 | |
|                     res['expect_names'].append(expect_name)
 | |
|                 else:
 | |
|                     res['expect_names'].append(os.path.join(root_dir, str(i) + '.txt'))
 | |
| 
 | |
|     return res
 | |
| 
 | |
| def loadtxt(file, shape, dtype=np.float32):
 | |
|     size  = np.prod(shape)
 | |
|     try:
 | |
|         data = np.loadtxt(fname=file, dtype=dtype).flatten()
 | |
|     except:
 | |
|         data = []
 | |
|         data_file = open(file, 'rt')
 | |
|         for line in data_file.readlines():
 | |
|             for x in line.strip().split(' '):
 | |
|                 try:
 | |
|                     a = float(x)
 | |
|                     data.append(a)
 | |
|                 except:
 | |
|                     pass
 | |
|         data = np.asarray(data)
 | |
|     if data.size >= size:
 | |
|         data = data[:size].reshape(shape)
 | |
|     else:
 | |
|         data = np.pad(data, (0, size - data.size), 'constant').reshape(shape)
 | |
|     return data
 | |
| 
 | |
| def MNNDataType2NumpyDataType(data_type):
 | |
|     if data_type == MNN.Halide_Type_Uint8:
 | |
|         return np.uint8
 | |
|     elif data_type == MNN.Halide_Type_Double:
 | |
|         return np.float64
 | |
|     elif data_type == MNN.Halide_Type_Int:
 | |
|         return np.int32
 | |
|     elif data_type == MNN.Halide_Type_Int64:
 | |
|         return np.int64
 | |
|     else:
 | |
|         return np.float32
 | |
| 
 | |
| def createTensor(tensor, file=''):
 | |
|     shape = tensor.getShape()
 | |
|     data_type = tensor.getDataType()
 | |
|     dtype = MNNDataType2NumpyDataType(data_type)
 | |
|     if file == '':
 | |
|         data = np.ones(shape, dtype=dtype)
 | |
|     else:
 | |
|         data = loadtxt(file, shape, dtype)
 | |
|     return MNN.Tensor(shape, tensor.getDataType(), data, tensor.getDimensionType())
 | |
| 
 | |
| def compareTensor(tensor, file, atol=5e-2):
 | |
|     outputNumpyData = tensor.getNumpyData()
 | |
|     expectNumpyData = loadtxt(file, tensor.getShape())
 | |
|     return np.allclose(outputNumpyData, expectNumpyData, atol=atol)
 | |
| 
 | |
| def log_result(success, model):
 | |
|     global total_num
 | |
|     global wrongs
 | |
|     total_num += 1
 | |
|     if success:
 | |
|         print('Test %s Correct!'%model)
 | |
|     else:
 | |
|         wrongs.append(model)
 | |
|         print('Test Failed %s!'%model)
 | |
| 
 | |
| def modelTest(modelPath, givenName, expectName):
 | |
|     print("Testing model %s, input: %s, output: %s\n" % (modelPath, givenName, expectName))
 | |
| 
 | |
|     net = MNN.Interpreter(modelPath)
 | |
|     session = net.createSession()
 | |
|     allInput = net.getSessionInputAll(session)
 | |
|     # input
 | |
|     inputTensor = net.getSessionInput(session)
 | |
|     inputHost = createTensor(inputTensor, givenName)
 | |
|     inputTensor.copyFrom(inputHost)
 | |
|     # infer
 | |
|     net.runSession(session)
 | |
|     outputTensor = net.getSessionOutput(session)
 | |
|     # output
 | |
|     outputShape = outputTensor.getShape()
 | |
|     outputHost = createTensor(outputTensor)
 | |
|     outputTensor.copyToHostTensor(outputHost)
 | |
|     # compare
 | |
|     success = compareTensor(outputHost, expectName)
 | |
|     log_result(success, modelPath)
 | |
| 
 | |
| def modelTestWithConfig(config):
 | |
|     model  = config['model_name']
 | |
|     inputs = config['input_names']
 | |
|     shapes = config['input_dims']
 | |
|     givens = config['given_names']
 | |
|     outputs = config['output_names']
 | |
|     expects = config['expect_names']
 | |
|     print("Testing model %s, input: %s, output: %s\n" % (model, givens, expects))
 | |
|     net = MNN.Interpreter(config['model_name'])
 | |
|     session = net.createSession()
 | |
|     all_input = net.getSessionInputAll(session)
 | |
|     # resize
 | |
|     for i in range(len(inputs)):
 | |
|         input = inputs[i]
 | |
|         shape = shapes[i]
 | |
|         net.resizeTensor(all_input[input], tuple(shape))
 | |
|     net.resizeSession(session)
 | |
|     # input
 | |
|     all_input = net.getSessionInputAll(session)
 | |
|     for i in range(len(inputs)):
 | |
|         input = inputs[i]
 | |
|         given = givens[i]
 | |
|         input_tensor = all_input[input]
 | |
|         input_host = createTensor(input_tensor, given)
 | |
|         input_tensor.copyFrom(input_host)
 | |
|     # infer
 | |
|     net.runSession(session)
 | |
|     all_output = net.getSessionOutputAll(session)
 | |
|     # output & compare
 | |
|     success = True
 | |
|     for i in range(len(outputs)):
 | |
|         output = outputs[i]
 | |
|         expect = expects[i]
 | |
|         output_tensor = all_output[output]
 | |
|         output_host = createTensor(output_tensor)
 | |
|         output_tensor.copyToHostTensor(output_host)
 | |
|         success &= compareTensor(output_host, expect)
 | |
|     # res
 | |
|     log_result(success, model)
 | |
| 
 | |
| def testResource(model_root_dir, name):
 | |
|     root_dir = os.path.join(model_root_dir, 'TestResource')
 | |
|     print('root: ' + root_dir + '\n')
 | |
|     for name in os.listdir(root_dir):
 | |
|         if name == '.DS_Store':
 | |
|             continue
 | |
|         modelName = os.path.join(root_dir, name, 'temp.bin')
 | |
|         inputName = os.path.join(root_dir, name, 'input_0.txt')
 | |
|         outputName = os.path.join(root_dir, name, 'output.txt')
 | |
|         modelTest(modelName, inputName, outputName)
 | |
| 
 | |
| def testTestWithDescribe(model_root_dir):
 | |
|     root_dir = os.path.join(model_root_dir, 'TestWithDescribe')
 | |
|     print('root: ' + root_dir + '\n')
 | |
|     for name in os.listdir(root_dir):
 | |
|         if name == '.DS_Store':
 | |
|             continue
 | |
|         config = parseConfig(os.path.join(root_dir, name))
 | |
|         if config:
 | |
|             modelTestWithConfig(config)
 | |
| 
 | |
| if __name__ == '__main__':
 | |
|     model_root_dir = sys.argv[1]
 | |
|     testResource(model_root_dir, 'TestResource')
 | |
|     testResource(model_root_dir, 'OpTestResource')
 | |
|     testTestWithDescribe(model_root_dir)
 | |
|     if len(wrongs) > 0:
 | |
|         print('Wrong: ', len(wrongs))
 | |
|         for wrong in wrongs:
 | |
|             print(wrong)
 | |
|     print('TEST_NAME_PYMNN_MODEL: Pymnn模型测试\nTEST_CASE_AMOUNT_PYMNN_MODEL: {\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n'%(len(wrongs), total_num - len(wrongs))) |