MNN/tools/script/genQNNModelsFromMNN.py

169 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import copy
import argparse
import os
import subprocess
import shutil # 导入 shutil 模块用于删除目录
def generate_all_configs(config_path, graph_name, qnn_sdk_root_path, src_model, executable_path, output_dir):
"""
为每个组合创建子目录生成配置文件并调用C++可执行文件进行模型转换。
"""
# --- 0. 准备工作 ---
# 创建主输出目录
os.makedirs(output_dir, exist_ok=True)
print(f"所有生成的文件将被保存在主目录: '{output_dir}'")
# 定义组合
combinations = [
[36, 'v69'],
[42, 'v69'],
[43, 'v73'],
[57, 'v75'],
[69, 'v79']
]
# --- 1. 读取模板文件 ---
htp_template_file = os.path.join(config_path, "htp_backend_extensions.json")
context_template_file = os.path.join(config_path, "context_config.json")
try:
with open(htp_template_file, 'r', encoding='utf-8') as f:
base_htp_data = json.load(f)
print(f"成功读取模板文件 '{htp_template_file}'")
with open(context_template_file, 'r', encoding='utf-8') as f:
base_context_data = json.load(f)
print(f"成功读取模板文件 '{context_template_file}'")
except FileNotFoundError as e:
print(f"错误:模板文件未找到。请确保 '{e.filename}' 存在于指定的路径中。")
return
except json.JSONDecodeError as e:
print(f"错误:文件格式无效。请检查 {e.doc} 是否为有效的JSON。")
return
# --- 2. 遍历组合,生成文件并执行命令 ---
for soc_id, dsp_arch in combinations:
print(f"\n{'='*15} 处理组合: soc_id={soc_id}, dsp_arch={dsp_arch} {'='*15}")
# --- 新增步骤: 为当前组合创建专用的子目录 ---
new_graph_name = f"{graph_name}_{soc_id}_{dsp_arch}"
graph_specific_dir = output_dir
# --- Part A: 生成 htp_backend_extensions 文件 (路径更新) ---
htp_config_data = copy.deepcopy(base_htp_data)
try:
htp_config_data["graphs"][0]["graph_names"] = [new_graph_name]
htp_config_data["devices"][0]["soc_id"] = soc_id
htp_config_data["devices"][0]["dsp_arch"] = dsp_arch
except (IndexError, KeyError) as e:
print(f"处理组合时出错: '{htp_template_file}' 结构不正确。错误: {e}")
continue
htp_output_filename = f"htp_backend_extensions_{soc_id}_{dsp_arch}.json"
# 更新路径,使其指向新的子目录
htp_output_filepath = os.path.join(graph_specific_dir, htp_output_filename)
with open(htp_output_filepath, 'w', encoding='utf-8') as f:
json.dump(htp_config_data, f, indent=4, ensure_ascii=False)
print(f"-> 已生成配置文件: '{htp_output_filepath}'")
# --- Part B: 生成 context_config 文件 (路径更新) ---
context_config_data = copy.deepcopy(base_context_data)
try:
# 这里的 htp_output_filename 是相对路径,这是正确的,
# 因为 context_config 和 htp_backend_extensions 在同一个目录中。
context_config_data["backend_extensions"]["config_file_path"] = htp_output_filepath
path_template = context_config_data["backend_extensions"]["shared_library_path"]
new_lib_path = path_template.replace("{QNN_SDK_ROOT}", qnn_sdk_root_path)
context_config_data["backend_extensions"]["shared_library_path"] = new_lib_path
except KeyError as e:
print(f"处理组合时出错: '{context_template_file}' 结构不正确,缺少键: {e}")
continue
context_output_filename = f"context_config_{soc_id}_{dsp_arch}.json"
# 更新路径,使其指向新的子目录
context_output_filepath = os.path.join(graph_specific_dir, context_output_filename)
with open(context_output_filepath, 'w', encoding='utf-8') as f:
json.dump(context_config_data, f, indent=4, ensure_ascii=False)
print(f"-> 已生成关联文件: '{context_output_filepath}'")
# --- Part C: 调用C++可执行命令 (路径更新) ---
dst_model_filename = f"{graph_name}_{soc_id}_{dsp_arch}.mnn"
# 更新路径,使其指向新的子目录
dst_model_filepath = os.path.join(graph_specific_dir, dst_model_filename)
graph_product_dir = os.path.join(graph_specific_dir, new_graph_name)
os.makedirs(graph_product_dir, exist_ok=True)
print(f"-> 已创建/确认子目录: '{graph_product_dir}'")
command = [
executable_path,
src_model,
dst_model_filepath,
qnn_sdk_root_path,
new_graph_name,
context_output_filepath
]
print("--> 准备执行命令...")
print(f" $ {' '.join(command)}")
try:
result = subprocess.run(command, check=True, capture_output=True, text=True)
print("--> 命令执行成功!")
# 即使成功,也打印 C++ 程序的输出,这对于查看警告等信息很有用
if result.stdout:
print(" --- C++程序输出 (stdout) ---")
print(result.stdout.strip())
print(" ------------------------------")
except FileNotFoundError:
print(f"!!! 命令执行失败: 可执行文件未找到 '{executable_path}'。请检查路径。")
break # 如果可执行文件找不到,直接退出循环
except subprocess.CalledProcessError as e:
# 这是关键的修改部分
print(f"!!! 命令执行失败 (返回码: {e.returncode})")
# 检查并打印 C++ 程序在失败前产生的标准输出
if e.stdout:
print(" --- C++程序输出 (stdout) ---")
print(e.stdout.strip())
print(" ------------------------------")
# 检查并打印 C++ 程序在失败前产生的标准错误(错误日志通常在这里)
if e.stderr:
print(" --- C++程序错误 (stderr) ---")
print(e.stderr.strip())
print(" ------------------------------")
except Exception as e:
print(f"!!! 执行期间发生未知错误: {e}")
finally:
# --- 步骤 3: 清理 ---
# 检查目录是否存在,然后删除
if os.path.exists(graph_product_dir):
print(f"--> 清理临时文件和目录: '{graph_product_dir}'")
shutil.rmtree(graph_product_dir)
else:
print("--> 无需清理,临时目录未创建。")
# --- 脚本执行入口 ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="为多个组合创建子目录生成QNN配置文件并调用模型转换工具。",
formatter_class=argparse.RawTextHelpFormatter
)
# ... (argparse部分保持完全不变) ...
gen_group = parser.add_argument_group('文件生成参数')
gen_group.add_argument("--config_path", required=True, help="[必需] 包含模板文件的目录路径。")
gen_group.add_argument("--graph_name", required=True, help="[必需] 模型图的名称 (不含soc_id等后缀)。")
gen_group.add_argument("--qnn_sdk_root_path", required=True, help="[必需] QNN SDK 的根路径。")
exec_group = parser.add_argument_group('模型转换参数')
exec_group.add_argument("--src_model", required=True, help="[必需] 源模型文件路径 (例如: my_model.mnn)。")
exec_group.add_argument("--executable_path", required=True, help="[必需] C++模型转换可执行文件的路径。")
exec_group.add_argument("--output_dir", default="./qnn_models", help="存放所有生成文件的输出目录 (默认: ./qnn_models)。")
args = parser.parse_args()
generate_all_configs(args.config_path, args.graph_name, args.qnn_sdk_root_path, args.src_model, args.executable_path, args.output_dir)