mirror of https://github.com/alibaba/MNN.git
Merge pull request #4061 from zlaazlaa/fix_diffusion
fix(diffusion): simplify export logic and fix dynamic axes

GitOrigin-RevId: cc6faf47f33d462e2e1ac613ec710ce55c39a86a
parent 308b8cd6e8
commit fe508afc19
@@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai
 cd mnn_path/transformers/diffusion/export
 python onnx_export.py \
     --model_path hf_sd_load_path \
-    --output_path onnx_save_path
+    --output_path onnx_save_path \
+    --opset 18
 ```
 Note: the script above depends on libraries such as torch/onnx/diffusers; a conda environment can be set up for them:
 ```
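For context on the new `--opset 18` flag: the opset number is presumably what the export helper forwards to `torch.onnx.export` as `opset_version`. Below is a minimal sketch of how such a flag typically flows through an export entry point; the argument wiring is illustrative, not a copy of `onnx_export.py`:

```python
import argparse
import torch

def export_module(module, example_inputs, output_path, opset):
    # opset_version selects which ONNX operator set the exporter targets;
    # newer opsets are required for some ops emitted by recent torch versions.
    torch.onnx.export(module, example_inputs, output_path, opset_version=opset)

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--opset", type=int, default=18)
args = parser.parse_args()
# ... load the diffusion pipeline from args.model_path and call
# export_module(...) for the text encoder, UNet and VAE submodules.
```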
@@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
     num_tokens = pipeline.text_encoder.config.max_position_embeddings
     text_hidden_size = pipeline.text_encoder.config.hidden_size
     text_input = pipeline.tokenizer(
-        "A sample prompt",
+        ["A sample prompt", "A sample prompt"],
         padding="max_length",
         max_length=pipeline.tokenizer.model_max_length,
         truncation=True,
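A plausible reading of this change: with dynamic axes disabled elsewhere in this commit, the example input's batch dimension is frozen into the exported graphs, and a batch of two prompts matches the conditional/unconditional pair used by classifier-free guidance at inference time. A small self-contained check of the resulting dummy-input shape, assuming the standard SD 1.x CLIP tokenizer (the checkpoint name here is illustrative):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_input = tokenizer(
    ["A sample prompt", "A sample prompt"],   # batch of 2 instead of a single string
    padding="max_length",
    max_length=tokenizer.model_max_length,    # 77 for this tokenizer
    truncation=True,
    return_tensors="pt",
)
print(text_input.input_ids.shape)  # torch.Size([2, 77]) -> the fixed shape baked into the export
```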
@@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
-        dynamic_axes={
-            "input_ids": {0: "batch", 1: "sequence"},
-        },
+        dynamic_axes=None,
         opset=opset,
     )
     del pipeline.text_encoder
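The effect of swapping the `dynamic_axes` dict for `None` can be seen with a tiny standalone export (a hypothetical toy module, not part of the repository): every dimension of the graph's inputs and outputs is frozen to the size of the example input, so the converted model accepts exactly one shape.

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2

# With dynamic_axes=None the exported graph declares input_ids as a fixed
# (2, 77) tensor instead of symbolic "batch"/"sequence" dimensions.
torch.onnx.export(
    Toy(),
    torch.zeros(2, 77, dtype=torch.int64),
    "toy_static.onnx",
    input_names=["input_ids"],
    output_names=["out"],
    dynamic_axes=None,
    opset_version=18,
)
```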
@@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
             # False,
         ),
         output_path=unet_path,
-        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
+        ordered_input_names=["sample", "timestep", "encoder_hidden_states"],
         output_names=["out_sample"],  # has to be different from "sample" for correct tracing
-        dynamic_axes={
-            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "timestep": {0: "batch"},
-            "encoder_hidden_states": {0: "batch", 1: "sequence"},
-        },
+        dynamic_axes=None,
         opset=opset,
         use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
     )
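`use_external_data_format=True` is kept because protobuf caps a single file at 2GB, so the UNet weights are written as separate tensor files next to `model.onnx`. A short hedged example of handling such a model (the `onnx_save_path` directory is the one from the README command above):

```python
import onnx

# onnx.load resolves the external weight files referenced from model.onnx;
# they must stay in the same directory as the graph.
unet = onnx.load("onnx_save_path/unet/model.onnx")

# For models larger than 2GB, pass the path (not the in-memory proto) to the
# checker so it can read the external data itself.
onnx.checker.check_model("onnx_save_path/unet/model.onnx")
```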
@@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
     vae_in_channels = vae_encoder.config.in_channels
     vae_sample_size = vae_encoder.config.sample_size
     # need to get the raw tensor output (sample) from the encoder
-    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
+    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode()
     onnx_export(
         vae_encoder,
         model_args=(
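The `.sample()` to `.mode()` switch matters because diffusers' `AutoencoderKL.encode` returns a `DiagonalGaussianDistribution`: `.sample()` draws random latents (and would bake sampling noise into the traced encoder), while `.mode()` returns the deterministic mean. A self-contained illustration with a tiny randomly initialised VAE (default config, not the Stable Diffusion weights):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # small default config, random weights
with torch.no_grad():
    dist = vae.encode(torch.zeros(1, 3, 32, 32)).latent_dist
    print(torch.allclose(dist.sample(), dist.sample()))  # False: each call adds fresh noise
    print(torch.allclose(dist.mode(), dist.mode()))      # True: the mode (mean) is deterministic
```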
@@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
-        dynamic_axes={
-            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
+        dynamic_axes=None,
         opset=opset,
     )

     # VAE DECODER
     vae_decoder = pipeline.vae
     vae_latent_channels = vae_decoder.config.latent_channels
     vae_out_channels = vae_decoder.config.out_channels
     # forward only through the decoder part
-    vae_decoder.forward = vae_encoder.decode
+    vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0]
     onnx_export(
         vae_decoder,
         model_args=(
             torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
             False,
         ),
         output_path=output_path / "vae_decoder" / "model.onnx",
-        ordered_input_names=["latent_sample", "return_dict"],
+        ordered_input_names=["latent_sample"],
         output_names=["sample"],
-        dynamic_axes={
-            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
+        dynamic_axes=None,
         opset=opset,
     )
     del pipeline.vae
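With `return_dict` dropped from the ordered inputs and `dynamic_axes=None`, the exported VAE decoder now has a single statically shaped tensor input and output. A quick hedged sanity check with onnxruntime (paths follow the README command above; the printed shapes depend on the exported model's configuration):

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "onnx_save_path/vae_decoder/model.onnx",
    providers=["CPUExecutionProvider"],
)
for inp in sess.get_inputs():
    print(inp.name, inp.shape)   # e.g. latent_sample [1, 4, 64, 64] -- no symbolic dims
for out in sess.get_outputs():
    print(out.name, out.shape)   # e.g. sample [1, 3, 512, 512]
```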