mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			36 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| from os import makedirs
 | |
| from os.path import join, basename, exists
 | |
| from shutil import copy, rmtree
 | |
| import argparse
 | |
| 
 | |
# Command-line interface. Both paths are mandatory; --models optionally
# restricts the run to one or more named model directories.
parser = argparse.ArgumentParser(
    description='Pull Test model and data from AliNNModel',
)
parser.add_argument(
    '--alinnmodel_path',
    dest='src_path',
    required=True,
    help='AliNNModel project path',
)
parser.add_argument(
    '--playground_path',
    dest='dest_path',
    required=True,
    help='Test Playground path',
)
parser.add_argument(
    '--models',
    dest='models',
    type=str,
    nargs='+',
    help='target models',
)
# Parsed at import time; main() reads this module-level namespace.
args = parser.parse_args()
 | |
| 
 | |
def main(options=None):
    """Copy test models from an AliNNModel checkout into the playground.

    The ``models`` directory under the playground path is removed and
    recreated. For every selected model directory, ``temp.bin``,
    ``input_0.txt`` and ``output.txt`` are copied into
    ``<playground>/models/<model name>``, and the model name is appended to
    ``<playground>/model_names.txt`` (one name per line).

    Args:
        options: Object exposing ``src_path``, ``dest_path`` and ``models``
            attributes (for example an ``argparse.Namespace``). When
            omitted, the module-level ``args`` parsed from the command line
            is used.

    Raises:
        OSError: If the source tree cannot be listed or one of the three
            expected files is missing from a model directory.
    """
    if options is None:
        options = args

    src_root = join(options.src_path, 'TestResource')
    models_root = join(options.dest_path, 'models')

    # Start from an empty destination so stale models never linger.
    if exists(models_root):
        rmtree(models_root)
    makedirs(models_root)

    if options.models:
        # normpath strips trailing separators so basename() below never
        # yields an empty name; dict.fromkeys drops repeated entries while
        # keeping the order the caller asked for.
        model_dirs = list(dict.fromkeys(
            os.path.normpath(join(src_root, requested))
            for requested in options.models
        ))
    else:
        # Close the directory iterator promptly and sort the result so the
        # generated name list does not depend on filesystem ordering.
        with os.scandir(src_root) as entries:
            model_dirs = sorted(
                entry.path for entry in entries if entry.is_dir()
            )

    record_path = join(options.dest_path, 'model_names.txt')
    with open(record_path, 'w') as record:
        for model_dir in model_dirs:
            model_name = basename(model_dir)
            record.write(model_name + '\n')
            target_dir = join(models_root, model_name)
            makedirs(target_dir)
            for file_name in ('temp.bin', 'input_0.txt', 'output.txt'):
                copy(join(model_dir, file_name), target_dir)
 | |
| 
 | |
# Run the copy only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
 |