Compare commits

...

165 Commits

Author SHA1 Message Date
Alexander Piskun acbf08cd60
feat(api-nodes): add support for 720p resolution for Kling Omni nodes (#11604) 2026-01-03 23:05:02 -08:00
comfyanonymous 53e762a3af
Print memory summary on OOM to help with debugging. (#11613) 2026-01-03 22:28:38 -05:00
comfyanonymous 9a552df898
Remove leftover scaled_fp8 key. (#11603) 2026-01-02 17:28:10 -08:00
Alexander Piskun f2fda021ab
Tripo3D: pass face_limit parameter only when it differs from default (#11601) 2026-01-02 03:18:43 -08:00
throttlekitty 303b1735f8
Give Mahiro CFG a more appropriate display name (#11580) 2026-01-02 00:37:37 -08:00
Alexander Piskun 9e5f677746
Ignore all frames except the first one for MPO format. (#11569) 2026-01-02 00:35:34 -08:00
comfyanonymous 65cfcf5b1b
New Year ruff cleanup. (#11595) 2026-01-01 22:06:14 -05:00
comfyanonymous 1bdc9a947f
Remove duplicate import of model_management (#11587) 2025-12-31 19:29:55 -05:00
comfyanonymous d622a61874
Refactor: move clip_preprocess to comfy.clip_model (#11586) 2025-12-31 17:38:36 -05:00
ComfyUI Wiki 236b9e211d
chore: update workflow templates to v0.7.65 (#11579) 2025-12-31 13:38:39 -08:00
Alexander Piskun 6ca3d5c011
fix(api-nodes-vidu): preserve percent-encoding for signed URLs (#11564) 2025-12-30 20:12:38 -08:00
Jedrzej Kosinski 0be8a76c93
V3 Improvements + DynamicCombo + Autogrow exposed in public API (#11345)
* Support Combo outputs in a more sane way

* Remove test validate_inputs function on test node

* Make curr_prefix be a list of strings instead of string for easier parsing as keys get added to dynamic types

* Start to account for id prefixes from frontend, need to fix bug with nested dynamics

* Ensure inputs/outputs/hidden are lists in schema finalize function, remove no longer needed 'is not None' checks

* Add raw_link and extra_dict to all relevant Inputs

* Make nested DynamicCombos work properly with prefixed keys on latest frontend; breaks old Autogrow, but is pretty much ready for upcoming Autogrow keys

* Replace ... usage with a MISSING sentinel for clarity in nodes_logic.py

* Added CustomCombo node in backend to reflect frontend node

* Prepare Autogrow's expand_schema_for_dynamic to work with upcoming frontend changes

* Prepare for look up table for dynamic input stuff

* More progress towards dynamic input lookup function stuff

* Finished converting _expand_schema_for_dynamic to be done via lookup instead of OOP to guarantee working with process isolation, did refactoring to remove old implementation + cleaning INPUT_TYPES definition including v3 hidden definition

* Change order of functions

* Removed some unneeded functions after dynamic refactor

* Make MatchType's output default displayname "MATCHTYPE"

* Fix DynamicSlot get_all

* Removed redundant code - dynamic stuff no longer happens in OOP way

* Natively support AnyType (*) without __ne__ hacks

* Remove stray code that made it in

* Remove expand_schema_for_dynamic left over on DynamicInput class

* get_dynamic() on DynamicInput/Output was not doing anything anymore, so removed it

* Make validate_inputs validate combo input correctly

* Temporarily comment out conversion to 'new' (9 month old) COMBO format in get_input_info

* Remove refrences to resources feature scrapped from V3

* Expose DynamicCombo in public API

* satisfy ruff after some code got commented out

* Make missing input error prettier for dynamic types

* Created a Switch2 node as a side-by-side test, will likely go with Switch2 as the initial switch node

* Figured out Switch situation

* Pass in v3_data in IsChangedCache.get function's fingerprint_inputs, add a from_v3_data helper method to HiddenHolder

* Switch order of Switch and Soft Switch nodes in file

* Temp test node for MatchType

* Fix missing v3_data for v1 nodes in validation

* For now, remove chacking duplicate id's for dynamic types

* Add Resize Image/Mask node that thanks to MatchType+DynamicCombo is 16-nodes-in-1

* Made DynamicCombo references in DCTestNode use public interface

* Add an AnyTypeTestNode

* Make lazy status for specific inputs on DynamicInputs work by having the values of the dictionary for check_lazy_status be a tuple, where the second element is the key of the input that can be returned

* Comment out test logic nodes

* Make primitive float's step make more sense

* Add (and leave commented out) some potential logic nodes

* Change default crop option to "center" on Resize Image/Mask node

* Changed copy.copy(d) to d.copy()

* Autogrow is available in stable  frontend, so exposing it in public API

* Use outputs id as display_name if no display_name present, remove v3 outputs id restriction that made them have to have unique IDs from the inputs

* Enable Custom Combo node as stable frontend now supports it

* Make id properly act like display_name on outputs

* Add Batch Images/Masks/Latents node

* Comment out Batch Images/Masks/Latents node for now, as Autogrow has a bug with MatchType where top connection is disconnected upon refresh

* Removed code for a couple test nodes in nodes_logic.py

* Add Batch Images, Batch Masks, and Batch Latents nodes with Autogrow, deprecate old Batch Images + LatentBatch nodes
2025-12-30 23:09:55 -05:00
mengqin 0357ed7ec4
Add support for sage attention 3 in comfyui, enable via new cli arg (#11026)
* Add support for sage attention 3 in comfyui, enable via new cli arg
--use-sage-attiention3

* Fix some bugs found in PR review. The N dimension at which Sage
Attention 3 takes effect is reduced to 1024 (although the improvement is
not significant at this scale).

* Remove the Sage Attention3 switch, but retain the attention function
registration.

* Fix a ruff check issue in attention.py
2025-12-30 22:53:52 -05:00
comfyanonymous f59f71cf34 ComfyUI version v0.7.0 2025-12-30 22:41:22 -05:00
drozbay 178bdc5e14
Add handling for vace_context in context windows (#11386)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-30 14:40:42 -08:00
Alexander Piskun 25a1bfab4e
chore(api-nodes-bytedance): mark "seededit" as deprecated, adjust display name of Seedream (#11490) 2025-12-30 08:33:34 -08:00
Tavi Halperin d7111e426a
ResizeByLongerSide: support video (#11555)
(cherry picked from commit 98c6840aa4e5fd5407ba9ab113d209011e474bf6)
2025-12-29 17:07:29 -08:00
comfyanonymous 0e6221cc79
Add some warnings for pin and unpin errors. (#11561) 2025-12-29 18:26:42 -05:00
rattus 9ca7e143af
mm: discard async errors from pinning failures (#10738)
Pretty much every error cudaHostRegister can throw also queues the same
error on the async GPU queue. This was fixed for repinning error case,
but there is the bad mmap and just enomem cases that are harder to
detect.

Do some dummy GPU work to clean the error state.
2025-12-29 18:19:34 -05:00
comfyanonymous 8fd07170f1
Comment out unused norm_final in lumina/z image model. (#11545) 2025-12-28 22:07:25 -05:00
comfyanonymous 2943093a53
Enable async offload by default for AMD. (#11534) 2025-12-27 18:54:15 -05:00
Alexander Piskun 36deef2c57
chore(api-nodes): switch to credits instead of $ (#11489) 2025-12-26 19:56:52 -08:00
Alexander Piskun 0d2e4bdd44
fix(api-nodes-gemini): always force enhance_prompt to be True (#11503) 2025-12-26 19:55:30 -08:00
Alexander Piskun eff4ea0b62
[V3] converted nodes_images.py to V3 schema (#11206)
* converted nodes_images.py to V3 schema

* fix test
2025-12-26 19:39:02 -08:00
Alexander Piskun 865568b7fc
feat(api-nodes): add Kling Motion Control node (#11493) 2025-12-26 19:16:21 -08:00
comfyanonymous 1e4e342f54
Fix noise with ancestral samplers when inferencing on cpu. (#11528) 2025-12-26 22:03:01 -05:00
Dr.Lt.Data 16fb6849d2
bump comfyui_manager version to the 4.0.4 (#11521) 2025-12-27 08:55:59 +09:00
comfyanonymous d9a76cf66e
Specify in readme that we only support pytorch 2.4 and up. (#11512) 2025-12-25 23:46:51 -05:00
comfyanonymous 532e285079
Add a ManualSigmas node. (#11499)
Can be used to manually set the sigmas for a model.

This node accepts a list of integer and floating point numbers separated
with any non numeric character.
2025-12-24 19:09:37 -05:00
ComfyUI Wiki 4f067b07fb
chore: update workflow templates to v0.7.64 (#11496) 2025-12-24 18:54:21 -05:00
Comfy Org PR Bot 650e716dda
Bump comfyui-frontend-package to 1.35.9 (#11470)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-23 21:29:41 -08:00
comfyanonymous e4c61d7555 ComfyUI v0.6.0 2025-12-23 20:50:02 -05:00
ComfyUI Wiki 22ff1bbfcb
chore: update workflow templates to v0.7.63 (#11482) 2025-12-23 20:48:45 -05:00
Alexander Piskun f4f44bb807
api-nodes: use new custom endpoint for Nano Banana (#11311) 2025-12-23 12:10:27 -08:00
comfyanonymous 33aa808713
Make denoised output on custom sampler nodes work with nested tensors. (#11471) 2025-12-22 16:43:24 -05:00
ComfyUI Wiki eb0e10aec4
Update workflow templates to v0.7.62 (#11467) 2025-12-22 16:02:41 -05:00
Alexander Piskun c176b214cc
extend possible duration range for Kling O1 StartEndFrame node (#11451) 2025-12-21 22:44:49 -08:00
comfyanonymous 91bf6b6aa3
Add node to create empty latents for qwen image layered model. (#11460) 2025-12-21 19:59:40 -05:00
comfyanonymous 807538fe6c
Core release process. (#11447) 2025-12-20 20:02:02 -05:00
Alexander Piskun bbb11e2608
fix(api-nodes): Topaz 4k video upscaling (#11438) 2025-12-20 08:48:28 -08:00
Alexander Piskun 0899012ad6
chore(api-nodes): by default set Watermark generation to False (#11437) 2025-12-19 22:24:37 -08:00
comfyanonymous fb478f679a
Only apply gemma quant config to gemma model for newbie. (#11436) 2025-12-20 01:02:43 -05:00
woctordho 4c432c11ed
Implement Jina CLIP v2 and NewBie dual CLIP (#11415)
* Implement Jina CLIP v2

* Support quantized Gemma in NewBie dual CLIP
2025-12-20 00:57:22 -05:00
comfyanonymous 31e961736a
Fix issue with batches and newbie. (#11435) 2025-12-20 00:23:51 -05:00
rattus 767ee30f21
ZImageFunControlNet: Fix mask concatenation in --gpu-only (#11421)
This operation trades in latents which in --gpu-only may be out of the GPU
The two VAE results will follow the --gpu-only defined behaviour so follow
the inpaint image device when calculating the mask in this path.
2025-12-20 00:22:17 -05:00
comfyanonymous 3ab9748903
Disable prompt weights on newbie te. (#11434) 2025-12-20 00:19:47 -05:00
woctordho 0aa7fa464e
Implement sliding attention in Gemma3 (#11409) 2025-12-20 00:16:46 -05:00
drozbay 514c24d756
Fix error from logging line (#11423)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-19 20:22:45 -08:00
comfyanonymous 809ce68749
Support nested tensor denoise masks. (#11431) 2025-12-19 19:59:25 -05:00
BradPepersAMD cc4ddba1b6
Allow enabling use of MIOpen by setting COMFYUI_ENABLE_MIOPEN=1 as an env var (#11366) 2025-12-19 17:01:50 -05:00
Dr.Lt.Data 8376ff6831
bump comfyui_manager version to the 4.0.3b7 (#11422) 2025-12-19 10:41:56 -08:00
Alexander Piskun 5b4d0664c8
add Flux2MaxImage API Node (#11420) 2025-12-19 10:02:49 -08:00
comfyanonymous 894802b0f9
Add LatentCutToBatch node. (#11411) 2025-12-18 22:21:40 -05:00
comfyanonymous 28eaab608b
Diffusion model part of Qwen Image Layered. (#11408)
Only thing missing after this is some nodes to make using it easier.
2025-12-18 20:21:14 -05:00
comfyanonymous 6a2678ac65
Trim/pad channels in VAE code. (#11406) 2025-12-18 18:22:38 -05:00
comfyanonymous e4fb3a3572
Support loading Wan/Qwen VAEs with different in/out channels. (#11405) 2025-12-18 17:45:33 -05:00
ComfyUI Wiki e8ebbe668e
chore: update workflow templates to v0.7.60 (#11403) 2025-12-18 17:09:29 -05:00
ric-yu 1ca89b810e
Add unified jobs API with /api/jobs endpoints (#11054)
* feat: create a /jobs api to return queue and history jobs

* update unused vars

* include priority

* create jobs helper file

* fix ruff

* update how we set error message

* include execution error in both responses

* rename error -> failed, fix output shape

* re-use queue and history functions

* set workflow id

* allow srot by exec duration

* fix tests

* send priority and remove error msg

* use ws messages to get start and end times

* revert main.py fully

* refactor: move all /jobs business logic to jobs.py

* fix failing test

* remove some tests

* fix non dict nodes

* address comments

* filter by workflow id and remove null fields

* add clearer typing - remove get("..") or ..

* refactor query params to top get_job(s) doc, add remove_sensitive_from_queue

* add brief comment explaining why we skip animated

* comment that format field is for frontend backward compatibility

* fix whitespace

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2025-12-17 21:44:31 -08:00
comfyanonymous bf7dc63bd6
skip_load_model -> force_full_load (#11390)
This should be a bit more clear and less prone to potential breakage if the
logic of the load models changes a bit.
2025-12-17 23:29:32 -05:00
Kohaku-Blueleaf 86dbb89fc9
Resolution bucketing and Trainer implementation refactoring (#11117) 2025-12-17 22:15:27 -05:00
comfyanonymous ba6080bbab ComfyUI v0.5.1 2025-12-17 21:04:50 -05:00
comfyanonymous 16d85ea133
Better handle torch being imported by prestartup nodes. (#11383) 2025-12-17 19:43:18 -05:00
chaObserv 5d9ad0c6bf
Fix the last step with non-zero sigma in sa_solver (#11380) 2025-12-17 13:57:40 -05:00
Alexander Piskun c08f97f344
fix regression in V3 nodes processing (#11375) 2025-12-17 10:24:25 -08:00
Alexander Piskun 887143854b
feat(api-nodes): add GPT-Image-1.5 (#11368) 2025-12-17 09:43:41 -08:00
comfyanonymous 3a5f239cb6 ComfyUI v0.5.0 2025-12-17 03:46:11 -05:00
chaObserv 827bb1512b
Add exp_heun_2_x0 sampler series (#11360) 2025-12-16 23:35:43 -05:00
comfyanonymous ffdd53b327
Check state dict key to auto enable the index_timestep_zero ref method. (#11362) 2025-12-16 17:03:17 -05:00
Alexander Piskun 65e2103b09
feat(api-nodes): add Wan2.6 model to video nodes (#11357) 2025-12-16 13:51:48 -08:00
Benjamin Lu 9304e47351
Update workflows for new release process (#11064)
* Update release workflows for branch process

* Adjust branch order in workflow triggers

* Revert changes in test workflows
2025-12-15 23:24:18 -08:00
comfyanonymous bc606d7d64
Add a way to set the default ref method in the qwen image code. (#11349) 2025-12-16 01:26:55 -05:00
comfyanonymous 645ee1881e
Inpainting for z image fun control. Use the ZImageFunControlnet node. (#11346)
image -> control image ex: pose
inpaint_image -> image for inpainting
mask -> inpaint mask
2025-12-15 23:38:12 -05:00
Christian Byrne 3d082c3206
bump comfyui-frontend-package to 1.34.9 (patch) (#11342) 2025-12-15 23:35:37 -05:00
comfyanonymous 683569de55
Only enable fp16 on ZImage on newer pytorch. (#11344) 2025-12-15 22:33:27 -05:00
Haoming ea2c117bc3
[BlockInfo] Wan (#10845)
* block info

* animate

* tensor

* device

* revert
2025-12-15 17:59:16 -08:00
Haoming fc4af86068
[BlockInfo] Lumina (#11227)
* block info

* device

* Make tensor int again

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-12-15 17:57:28 -08:00
comfyanonymous 41bcf0619d
Add code to detect if a z image fun controlnet is broken or not. (#11341) 2025-12-15 20:51:06 -05:00
seed93 d02d0e5744
[add] tripo3.0 (#10663)
* [add] tripo3.0

* [tripo] change paramter order

* change order

---------

Co-authored-by: liangd <liangding@vastai3d.com>
2025-12-15 17:38:46 -08:00
comfyanonymous 70541d4e77
Support the new qwen edit 2511 reference method. (#11340)
index_timestep_zero can be selected in the
FluxKontextMultiReferenceLatentMethod now with the display name set to the
more generic "Edit Model Reference Method" node.
2025-12-15 19:20:34 -05:00
drozbay 77b2f7c228
Add context windows callback for custom cond handling (#11208)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-15 16:06:32 -08:00
Alexander Piskun 43e0d4e3cc
comfy_api: remove usage of "Type","List" and "Dict" types (#11238) 2025-12-15 16:01:10 -08:00
Dr.Lt.Data dbd330454a
feat(preview): add per-queue live preview method override (#11261)
- Add set_preview_method() to override live preview method per queue item
- Read extra_data.preview_method from /prompt request
- Support values: taesd, latent2rgb, none, auto, default
- "default" or unset uses server's CLI --preview-method setting
- Add 44 tests (37 unit + 7 E2E)
2025-12-15 15:57:39 -08:00
Alexander Piskun 33c7f1179d
drop Pika API nodes (#11306) 2025-12-15 15:32:29 -08:00
Alexander Piskun af91eb6c99
api-nodes: drop Kling v1 model (#11307) 2025-12-15 15:30:24 -08:00
comfyanonymous 5cb1e0c9a0
Disable guards on transformer_options when torch.compile (#11317) 2025-12-15 16:49:29 -05:00
ComfyUI Wiki 51347f9fb8
chore: update workflow templates to v0.7.59 (#11337) 2025-12-15 16:28:55 -05:00
Dr.Lt.Data a5e85017d8
bump manager requirments to the 4.0.3b5 (#11324) 2025-12-15 14:24:01 -05:00
comfyanonymous 5ac3b26a7d
Update warning for old pytorch version. (#11319)
Versions below 2.4 are no longer supported. We will not break support on purpose but will not fix it if we do.
2025-12-14 04:02:50 -05:00
chaObserv 6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
comfyanonymous 971cefe7d4
Fix pytorch warnings. (#11314) 2025-12-13 18:45:23 -05:00
comfyanonymous da2bfb5b0a
Basic implementation of z image fun control union 2.0 (#11304)
The inpaint part is currently missing and will be implemented later.

I think they messed up this model pretty bad. They added some
control_noise_refiner blocks but don't actually use them. There is a typo
in their code so instead of doing control_noise_refiner -> control_layers
it runs the whole control_layers twice.

Unfortunately they trained with this typo so the model works but is kind
of slow and would probably perform a lot better if they corrected their
code and trained it again.
2025-12-13 01:39:11 -05:00
comfyanonymous c5a47a1692
Fix bias dtype issue in mixed ops. (#11293) 2025-12-12 11:49:35 -05:00
Alexander Piskun 908fd7d749
feat(api-nodes): new TextToVideoWithAudio and ImageToVideoWithAudio nodes (#11267) 2025-12-12 00:18:31 -08:00
comfyanonymous 5495589db3
Respect the dtype the op was initialized in for non quant mixed op. (#11282) 2025-12-11 23:32:27 -05:00
Jukka Seppänen 982876d59a
WanMove support (#11247) 2025-12-11 22:29:34 -05:00
comfyanonymous 338d9ae3bb
Make portable updater work with repos in unmerged state. (#11281) 2025-12-11 18:56:33 -05:00
comfyanonymous eeb020b9b7
Better chroma radiance and other models vram estimation. (#11278) 2025-12-11 17:33:09 -05:00
comfyanonymous ae65433a60
This only works on radiance. (#11277) 2025-12-11 17:15:00 -05:00
comfyanonymous fdebe18296
Fix regular chroma radiance (#11276) 2025-12-11 17:09:35 -05:00
comfyanonymous f8321eb57b
Adjust memory usage factor. (#11257) 2025-12-11 01:30:31 -05:00
Alexander Piskun 93948e3fc5
feat(api-nodes): enable Kling Omni O1 node (#11229) 2025-12-10 22:11:12 -08:00
Farshore e711aaf1a7
Lower VAE loading requirements:Create a new branch for GPU memory calculations in qwen-image vae (#11199) 2025-12-10 22:02:26 -05:00
Johnpaul Chiwetelu 57ddb7fd13
Fix: filter hidden files from /internal/files endpoint (#11191) 2025-12-10 21:49:49 -05:00
comfyanonymous 17c92a9f28
Tweak Z Image memory estimation. (#11254) 2025-12-10 19:59:48 -05:00
Alexander Piskun 36357bbcc3
process the NodeV1 dict results correctly (#11237) 2025-12-10 11:55:09 -08:00
Benjamin Lu f668c2e3c9
bump comfyui-frontend-package to 1.34.8 (#11220) 2025-12-09 22:27:07 -05:00
comfyanonymous fc657f471a ComfyUI version v0.4.0
From now on ComfyUI will do version numbers a bit differently, every stable
off the master branch will increment the minor version. Anytime a fix needs
to be backported onto a stable version the patch version will be
incremented.

Example: We release v0.6.0 off the master branch then a day later a bug is
discovered and we decide to backport the fix onto the v0.6.0 stable, this
will be done in a separate branch in the main repository and this new
stable will be tagged v0.6.1
2025-12-09 18:26:49 -05:00
comfyanonymous 791e30ff50
Fix nan issue when quantizing fp16 tensor. (#11213) 2025-12-09 17:03:21 -05:00
Jukka Seppänen e2a800e7ef
Fix for HunyuanVideo1.5 meanflow distil (#11212) 2025-12-09 16:59:16 -05:00
rattus 9d252f3b70
ops: delete dead code (#11204)
This became dead code in https://github.com/comfyanonymous/ComfyUI/pull/11069
2025-12-09 00:55:13 -05:00
Lodestone b9fb542703
add chroma-radiance-x0 mode (#11197) 2025-12-08 23:33:29 -05:00
Christian Byrne cabc4d351f
bump comfyui-frontend-package to 1.33.13 (patch) (#11200) 2025-12-08 23:22:02 -05:00
rattus e136b6dbb0
dequantization offload accounting (fixes Flux2 OOMs - incl TEs) (#11171)
* make setattr safe for non existent attributes

Handle the case where the attribute doesnt exist by returning a static
sentinel (distinct from None). If the sentinel is passed in as the set
value, del the attr.

* Account for dequantization and type-casts in offload costs

When measuring the cost of offload, identify weights that need a type
change or dequantization and add the size of the conversion result
to the offload cost.

This is mutually exclusive with lowvram patches which already has
a large conservative estimate and wont overlap the dequant cost so\
dont double count.

* Set the compute type on CLIP MPs

So that the loader can know the size of weights for dequant accounting.
2025-12-08 23:21:31 -05:00
comfyanonymous d50f342c90
Fix potential issue. (#11201) 2025-12-08 23:20:04 -05:00
comfyanonymous 3b0368aa34
Fix regression. (#11194) 2025-12-08 17:38:36 -05:00
ComfyUI Wiki 935493f6c1
chore: update workflow templates to v0.7.54 (#11192) 2025-12-08 15:18:53 -05:00
rattus 60ee574748
retune lowVramPatch VRAM accounting (#11173)
In the lowvram case, this now does its math in the model dtype in the
post de-quantization domain. Account for that. The patching was also
put back on the compute stream getting it off-peak so relax the
MATH_FACTOR to only x2 so get out of the worst-case assumption of
everything peaking at once.
2025-12-08 15:18:06 -05:00
dxqb 8e889c535d
Support "transformer." LoRA prefix for Z-Image (#11135) 2025-12-08 15:17:26 -05:00
Alexander Piskun fd271dedfd
[API Nodes] add support for seedance-1-0-pro-fast model (#10947)
* feat(api-nodes): add support for seedance-1-0-pro-fast model

* feat(api-nodes): add support for seedream-4.5 model
2025-12-08 01:33:46 -08:00
Alexander Piskun c3c6313fc7
Added "system_prompt" input to Gemini nodes (#11177) 2025-12-08 01:28:17 -08:00
Alexander Piskun 85c4b4ae26
chore: replace imports of deprecated V1 classes (#11127) 2025-12-08 01:27:02 -08:00
ComfyUI Wiki 058f084371
Update workflow templates to v0.7.51 (#11150)
* chore: update workflow templates to v0.7.50

* Update template to 0.7.51
2025-12-08 01:22:51 -08:00
Alexander Piskun ec7f65187d
chore(comfy_api): replace absolute imports with relative (#11145) 2025-12-08 01:21:41 -08:00
comfyanonymous 56fa7dbe38
Properly load the newbie diffusion model. (#11172)
There is still one of the text encoders missing and I didn't actually test it.
2025-12-07 07:44:55 -05:00
comfyanonymous 329480da5a
Fix qwen scaled fp8 not working with kandinsky. Make basic t2i wf work. (#11162) 2025-12-06 17:50:10 -08:00
rattus 4086acf3c2
Fix on-load VRAM OOM (#11144)
slow down the CPU on model load to not run ahead. This fixes a VRAM on
flux 2 load.

I went to try and debug this with the memory trace pickles, which needs
--disable-cuda-malloc which made the bug go away. So I tried this
synchronize and it worked.

The has some very complex interactions with the cuda malloc async and
I dont have solid theory on this one yet.

Still debugging but this gets us over the OOM for the moment.
2025-12-06 18:42:09 -05:00
comfyanonymous 50ca97e776
Speed up lora compute and lower memory usage by doing it in fp16. (#11161) 2025-12-06 18:36:20 -05:00
Jukka Seppänen 7ac7d69d94
Fix EmptyAudio node input types (#11149) 2025-12-06 10:09:44 -08:00
Alexander Piskun 76f18e955d
marked all Pika API nodes a deprecated (#11146) 2025-12-06 03:28:08 -08:00
comfyanonymous d7a0aef650
Set OCL_SET_SVM_SIZE on AMD. (#11139) 2025-12-06 00:15:21 -05:00
Alexander Piskun 913f86b727
[V3] convert nodes_mask.py to V3 schema (#10669)
* convert nodes_mask.py to V3 schema

* set "Preview Mask" as display name for MaskPreview
2025-12-05 20:24:10 -08:00
Alexander Piskun 117bf3f2bd
convert nodes_freelunch.py to the V3 schema (#10904) 2025-12-05 20:22:02 -08:00
comfyanonymous ae676ed105
Fix regression. (#11137) 2025-12-05 23:01:19 -05:00
Jukka Seppänen fd109325db
Kandinsky5 model support (#10988)
* Add Kandinsky5 model support

lite and pro T2V tested to work

* Update kandinsky5.py

* Fix fp8

* Fix fp8_scaled text encoder

* Add transformer_options for attention

* Code cleanup, optimizations, use fp32 for all layers originally at fp32

* ImageToVideo -node

* Fix I2V, add necessary latent post process nodes

* Support text to image model

* Support block replace patches (SLG mostly)

* Support official LoRAs

* Don't scale RoPE for lite model as that just doesn't work...

* Update supported_models.py

* Rever RoPE scaling to simpler one

* Fix typo

* Handle latent dim difference for image model in the VAE instead

* Add node to use different prompts for clip_l and qwen25_7b

* Reduce peak VRAM usage a bit

* Further reduce peak VRAM consumption by chunking ffn

* Update chunking

* Update memory_usage_factor

* Code cleanup, don't force the fp32 layers as it has minimal effect

* Allow for stronger changes with first frames normalization

Default values are too weak for any meaningful changes, these should probably be exposed as advanced node options when that's available.

* Add image model's own chat template, remove unused image2video template

* Remove hard error in ReplaceVideoLatentFrames -node

* Update kandinsky5.py

* Update supported_models.py

* Fix typos in prompt template

They were now fixed in the original repository as well

* Update ReplaceVideoLatentFrames

Add tooltips
Make source optional
Better handle negative index

* Rename NormalizeVideoLatentFrames -node

For bit better clarity what it does

* Fix NormalizeVideoLatentStart node out on non-op
2025-12-05 22:20:22 -05:00
Dr.Lt.Data bed12674a1
docs: add ComfyUI-Manager documentation and update to v4.0.3b4 (#11133)
- Add manager setup instructions and command line options to README
- Document --enable-manager, --enable-manager-legacy-ui, and
  --disable-manager-ui flags
- Bump comfyui_manager version from 4.0.3b3 to 4.0.3b4
2025-12-05 15:45:38 -08:00
comfyanonymous 092ee8a500
Fix some custom nodes. (#11134) 2025-12-05 18:25:31 -05:00
Jukka Seppänen 79d17ba233
Context windows fixes and features (#10975)
* Apply cond slice fix

* Add FreeNoise

* Update context_windows.py

* Add option to retain condition by indexes for each window

This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations.

* Update context_windows.py

* Allow splitting multiple conds into different windows

* Add handling for audio_embed

* whitespace

* Allow freenoise to work on other dims, handle 4D batch timestep

Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now.

* Disable experimental options for now

So that  the Freenoise and bugfixes can be merged first

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-05 12:42:46 -08:00
comfyanonymous 6fd463aec9
Fix regression when text encoder loaded directly on GPU. (#11129) 2025-12-05 15:33:16 -05:00
comfyanonymous 43071e3de3
Make old scaled fp8 format use the new mixed quant ops system. (#11000) 2025-12-05 14:35:42 -05:00
Jedrzej Kosinski 0ec05b1481
Remove line made unnecessary (and wrong) after transformer_options was added to NextDiT's _forward definition (#11118) 2025-12-05 14:05:38 -05:00
comfyanonymous 35fa091340
Forgot to put this in README. (#11112) 2025-12-04 22:52:09 -05:00
Alexander Piskun 3c8456223c
[API Nodes]: fixes and refactor (#11104)
* chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder

* chore(api-nodes): add validate_video_frame_count function from LTX PR

* chore(api-nodes): replace deprecated V1 imports

* fix(api-nodes): the types returned by the "poll_op" function are now correct.
2025-12-04 14:05:28 -08:00
rattus 9bc893c5bb
sd: bump HY1.5 VAE estimate (#11107)
Im able to push vram above estimate on partial unload. Bump the
estimate. This is experimentally determined with a 720P and 480P
datapoint calibrating for 24GB VRAM total.
2025-12-04 09:50:36 -08:00
rattus f4bdf5f830
sd: revise hy VAE VRAM (#11105)
This was recently collapsed down to rolling VAE through temporal. Clamp
The time dimension.
2025-12-04 09:50:04 -08:00
rattus 6be85c7920
mp: use look-ahead actuals for stream offload VRAM calculation (#11096)
TIL that the WAN TE has a 2GB weight followed by 16MB as the next size
down. This means that team 8GB VRAM would fully offload the TE in async
offload mode as it just multiplied this giant size my the num streams.

Do the more complex logic of summing up the upcoming to-load weight
sizes to avoid triple counting this massive weight.

partial unload does the converse of recording the NS most recent
unloads as they go.
2025-12-03 23:28:44 -05:00
comfyanonymous ea17add3c6
Fix case where text encoders where running on the CPU instead of GPU. (#11095) 2025-12-03 23:15:15 -05:00
comfyanonymous ecdc8697d5
Qwen Image Lora training fix from #11090 (#11094) 2025-12-03 22:49:28 -05:00
Alexander Piskun dce518c2b4
convert nodes_audio.py to V3 schema (#10798) 2025-12-03 17:35:04 -08:00
Alexander Piskun 440268d394
convert nodes_load_3d.py to V3 schema (#10990) 2025-12-03 13:52:31 -08:00
Alexander Piskun 87c104bfc1
add support for "@image" reference format in Kling Omni API nodes (#11082) 2025-12-03 08:55:44 -08:00
Alexander Piskun 19f2192d69
fix(V3-Schema): use empty list defaults for Schema.inputs/outputs/hidden to avoid None issues (#11083) 2025-12-03 08:37:35 -08:00
rattus 519c941165
Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069)
* mp: only count the offload cost of math once

This was previously bundling the combined weight storage and computation
cost

* ops: put all post async transfer compute on the main stream

Some models have massive weights that need either complex
dequantization or lora patching. Don't do these patchings on the offload
stream, instead do them on the main stream to syncrhonize the
potentially large vram spikes for these compute processes. This avoids
having to assume a worst case scenario of multiple offload streams
all spiking VRAM is parallel with whatever the main stream is doing.
2025-12-03 02:28:45 -05:00
comfyanonymous 861817d22d
Fix issue with portable updater. (#11070)
This should fix the problem with the portable updater not working with portables created from a separate branch on the repo.

This does not affect any current portables who were all created on the master branch.
2025-12-03 00:47:51 -05:00
Jedrzej Kosinski c120eee5ba
Add MatchType, DynamicCombo, and Autogrow support to V3 Schema (#10832)
* Added output_matchtypes to generated json for v3, initial backend support for MatchType, created nodes_logic.py and added SwitchNode

* Fixed providing list of allowed_types

* Add workaround in validation.py for V3 Combo outputs not working as Combo inputs

* Make match type receive_type pass validation

* Also add MatchType check to input_type in validation - will likely trigger when connecting to non-lazy stuff

* Make sure this PR only has MatchType stuff

* Initial work on DynamicCombo

* Add get_dynamic function, not yet filled out correctly

* Mark Switch node as Beta

* Make sure other unfinished dynamic types are not accidentally used

* Send DynamicCombo.Option inputs in the same format as normal v1 inputs

* add dynamic combo test node

* Support validation of inputs and outputs

* Add missing input params to DynamicCombo.Input

* Add get_all function to inputs for id validation purposes

* Fix imports for v3 returning everything when doing io/ui/IO/UI instead of what is in __all__ of _io.py and _ui.py

* Modifying behavior of get_dynamic in V3 + serialization so can be used in execution code

* Fix v3 schema validation code after changes

* Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet)

* Add nesting of inputs on DynamicCombo during execution

* Work with latest frontend commits

* Fix cringe arrows

* frontend will no longer namespace dynamic inputs widgets so reflect that in code, refactor build_nested_inputs

* Prepare Autogrow support for the love of the game

* satisfy ruff

* Create test nodes for Autogrow to collab with frontend development

* Add nested combo to DCTestNode

* Remove array support from build_nested_inputs, properly handle missing expected values

* Make execution.validate_inputs properly validate required dynamic inputs, renamed dynamic_data to dynamic_paths for clarity

* MatchType does not need any DynamicInput/Output features on backend; will increase compatibility with  dynamic types

* Probably need this for ruff check

* Change MatchType to have template be the first and only required param; output id's do nothing right now, so no need

* Fix merge regression with LatentUpscaleModel type not being put in __all__ for _io.py, fix invalid type hint for validate_inputs

* Make Switch node inputs optional, disallow both inputs from being missing, and still work properly with lazy; when one input is missing, use the other no matter what the switch is set to

* Satisfy ruff

* Move MatchType code above the types that inherit from DynamicInput

* Add DynamicSlot type, awaiting frontend support

* Make curr_prefix creation happen in Autogrow, move curr_prefix in DynamicCombo to only be created if input exists in live_inputs

* I was confused, fixing accidentally redundant curr_prefix addition in Autogrow

* Make sure Autogrow inputs are force_input = True when WidgetInput, fix runtime validation by removing original input from expected inputs, fix min/max bounds, change test nodes slightly

* Remove unnecessary id usage in Autogrow test node outputs

* Commented out Switch node + test nodes

* Remove commented out code from Autogrow

* Make TemplatePrefix max more clear, allow max == 1

* Replace all dict[str] with dict[str, Any]

* Renamed add_to_dict_live_inputs to expand_schema_for_dynamic

* Fixed typo in DynamicSlot input code

* note about live_inputs not being present soon in get_v1_info (internal function anyway)

* For now, hide DynamicCombo and Autogrow from public interface

* Removed comment
2025-12-03 00:17:13 -05:00
rattus 73f5649196
Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995)
* hunyuan upsampler: rework imports

Remove the transitive import of VideoConv3d and Resnet and takes these
from actual implementation source.

* model: remove unused give_pre_end

According to git grep, this is not used now, and was not used in the
initial commit that introduced it (see below).

This semantic is difficult to implement temporal roll VAE for (and would
defeat the purpose). Rather than implement the complex if, just delete
the unused feature.

(venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline
220afe33 (HEAD) Initial commit.
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py:                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py:        self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py:        if self.give_pre_end:

(venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master
Previous HEAD position was 220afe33 Initial commit.
HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953)
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py:                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py:        self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py:        if self.give_pre_end:

* move refiner VAE temporal roller to core

Move the carrying conv op to the common VAE code and give it a better
name. Roll the carry implementation logic for Resnet into the base
class and scrap the Hunyuan specific subclass.

* model: Add temporal roll to main VAE decoder

If there are no attention layers, its a standard resnet and VideoConv3d
is asked for, substitute in the temporal rolloing VAE algorithm. This
reduces VAE usage by the temporal dimension (can be huge VRAM savings).

* model: Add temporal roll to main VAE encoder

If there are no attention layers, its a standard resnet and VideoConv3d
is asked for, substitute in the temporal rolling VAE algorithm. This
reduces VAE usage by the temporal dimension (can be huge VRAM savings).
2025-12-02 22:49:29 -05:00
Jim Heising 3f512f5659
Added PATCH method to CORS headers (#11066)
Added PATCH http method to access-control-allow-header-methods header because there are now PATCH endpoints exposed in the API.

See 277237ccc1/api_server/routes/internal/internal_routes.py (L34) for an example of an API endpoint that uses the PATCH method.
2025-12-02 22:29:27 -05:00
comfyanonymous b94d394a64
Support Z Image alibaba pai fun controlnets. (#11062)
These are not actual controlnets so put it in the models/model_patches
folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to
use it.
2025-12-02 21:38:31 -05:00
rattus 277237ccc1
attention: use flag based OOM fallback (#11038)
Exception ref all local variables for the lifetime of exception
context. Just set a flag and then if to dump the exception before
falling back.
2025-12-02 17:24:19 -05:00
comfyanonymous daaceac769
Hack to make zimage work in fp16. (#11057) 2025-12-02 17:11:58 -05:00
Alexander Piskun 33d6aec3b7
add check for the format arg type in VideoFromComponents.save_to function (#11046)
* add check for the format var type in VideoFromComponents.save_to function

* convert "format" to VideoContainer enum
2025-12-02 11:50:13 -08:00
Jedrzej Kosinski 44baa0b7f3
Fix CODEOWNERS formatting to have all on the same line, otherwise only last line applies (#11053) 2025-12-02 11:46:29 -08:00
Yoland Yan a17cf1c387
Add @guill as a code owner (#11031) 2025-12-01 22:40:44 -05:00
Dr.Lt.Data b4a20acc54
feat: Support ComfyUI-Manager for pip version (#7555) 2025-12-01 22:32:52 -05:00
Christian Byrne c55dc857d5
bump comfyui-frontend-package to 1.33.10 (#11028) 2025-12-01 20:56:38 -05:00
comfyanonymous 878db3a727
Implement the Ovis image model. (#11030) 2025-12-01 20:56:17 -05:00
156 changed files with 9863 additions and 3758 deletions

View File

@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
@ -66,8 +76,10 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
print("pulling.") # noqa: T201
pull(repo)
print("fetching.") # noqa: T201
for remote in repo.remotes:
if remote.name == "origin":
remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@ -149,3 +161,4 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@ -5,6 +5,7 @@ on:
push:
branches:
- master
- release/**
paths-ignore:
- 'app/**'
- 'input/**'

View File

@ -2,9 +2,9 @@ name: Execution Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Unit Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -6,6 +6,7 @@ on:
- "pyproject.toml"
branches:
- master
- release/**
jobs:
update-version:

View File

@ -1,3 +1,2 @@
# Admins
* @comfyanonymous
* @kosinkadink
* @comfyanonymous @kosinkadink @guill

View File

@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@ -118,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
@ -208,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
### Instructions:
Git clone this repo.
@ -319,6 +325,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
### Setup
1. Install the manager dependencies:
```bash
pip install -r manager_requirements.txt
```
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
```bash
python main.py --enable-manager
```
### Command Line Options
| Flag | Description |
|------|-------------|
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
# Running
```python main.py```

View File

@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)

View File

@ -44,7 +44,7 @@ class ModelFileManager:
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
if folder not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@ -55,7 +55,7 @@ class ModelFileManager:
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]

View File

@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
@ -121,6 +128,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@ -168,6 +181,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

View File

@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()

View File

@ -1,6 +1,5 @@
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import json
import logging
@ -17,24 +16,7 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
@ -73,7 +55,7 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()

View File

@ -51,32 +51,43 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
idx = tuple([slice(None)] * dim + [self.index_list])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
window[idx] = full[idx]
return window.to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = [slice(None)] * dim + [self.index_list]
idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full
def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self):
return {}
@ -94,7 +105,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -103,13 +115,18 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False
@ -123,6 +140,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@ -145,13 +167,38 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# Handle vace_context (temporal dim is 3)
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
vace_cond = cond_value.cond
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@ -164,7 +211,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@ -173,7 +220,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@ -250,8 +297,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
idx_window = tuple([slice(None)] * self.dim + [idx])
pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@ -287,6 +334,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
model_options = extra_args.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is None:
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
"ContextWindows_sampler_sample",
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -538,3 +607,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise")
generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[dim]
delta = context_length - context_overlap
for start_idx in range(0, latent_video_length - context_length, delta):
place_idx = start_idx + context_length
actual_delta = min(delta, latent_video_length - place_idx)
if actual_delta <= 0:
break
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
noise[tuple(target_slice)] = noise[tuple(source_slice)]
return noise

View File

@ -527,7 +527,8 @@ class HookKeyframeGroup:
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
else:
break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on

View File

@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
def default_noise_sampler(x, seed=None):
if seed is not None:
if x.device == torch.device("cpu"):
seed += 1
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
@ -1557,10 +1560,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1606,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
@ -1609,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
@ -1756,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
x_pred = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@ -1777,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
return x_pred
@torch.no_grad()

View File

@ -40,7 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
txt_ids_dims: list
vec_in_dim: int

View File

@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
class ChromaRadiance(Chroma):
"""
@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
self.skip_dit = []
self.lite = False
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -267,7 +270,7 @@ class ChromaRadiance(Chroma):
bad_keys = tuple(
k
for k, v in overrides.items()
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
)
if bad_keys:
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
def _apply_x0_residual(self, predicted, noisy, timesteps):
# non zero during training to prevent 0 div
eps = 0.0
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
def _forward(
self,
x: Tensor,
@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
# If x0 variant → v-pred, just return this instead
if hasattr(self, "__x0__"):
out = self._apply_x0_residual(out, img, timestep)
return out

View File

@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class YakMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
self.act_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
if yak_mlp:
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
if mlp_silu_act:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
if mlp_silu_act:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
modulation=True,
mlp_silu_act=False,
bias=True,
yak_mlp=False,
dtype=None,
device=None,
operations=None
@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
if self.yak_mlp:
self.mlp_hidden_dim_first *= 2
self.mlp_act = nn.SiLU()
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
mlp = self.mlp_act(mlp)
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
else:
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:

View File

@ -15,7 +15,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Modulation
Modulation,
RMSNorm
)
@dataclass
@ -34,11 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
yak_mlp: bool = False
txt_norm: bool = False
class Flux(nn.Module):
@ -76,6 +80,11 @@ class Flux(nn.Module):
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
else:
self.txt_norm = None
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@ -86,6 +95,7 @@ class Flux(nn.Module):
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -94,7 +104,7 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@ -150,6 +160,8 @@ class Flux(nn.Module):
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.txt_norm is not None:
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
@ -332,8 +344,9 @@ class Flux(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
if len(self.params.axes_dim) == 4: # Flux 2
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
if len(self.params.txt_ids_dims) > 0:
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]

View File

@ -43,6 +43,7 @@ class HunyuanVideoParams:
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
meanflow_sum: bool
class SelfAttentionRef(nn.Module):
@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)

View File

@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management
import model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):

View File

@ -1,42 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
class NoPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
if isinstance(op, NoPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
out = op(x)
return out
class RMS_norm(nn.Module):
def __init__(self, dim):
@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
return h + x
class HunyuanRefinerResnetBlock(ResnetBlock):
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
h = x
h = [ self.swish(self.norm1(x)) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = [ self.dropout(self.swish(self.norm2(h))) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x+h
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@ -191,7 +144,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = NoPadConv3d
conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -206,9 +159,10 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@ -218,9 +172,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@ -246,22 +200,20 @@ class Encoder(nn.Module):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
out = torch_cat_if_needed(out, dim=2)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
@ -288,7 +240,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = NoPadConv3d
conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -298,9 +250,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -308,9 +260,10 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@ -340,7 +293,7 @@ class Decoder(nn.Module):
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@ -350,10 +303,7 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out
del x
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
out = torch_cat_if_needed(out, dim=2)
if not self.refiner_vae:
if z.shape[-3] == 1:

View File

@ -0,0 +1,413 @@
import torch
from torch import nn
import math
import comfy.ldm.common_dit
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
def attention(q, k, v, heads, transformer_options={}):
return optimized_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
heads=heads,
skip_reshape=True,
transformer_options=transformer_options
)
def apply_scale_shift_norm(norm, x, scale, shift):
return torch.addcmul(shift, norm(x), scale + 1.0)
def apply_gate_sum(x, out, gate):
return torch.addcmul(x, gate, out)
def get_shift_scale_gate(params):
shift, scale, gate = torch.chunk(params, 3, dim=-1)
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
def get_freqs(dim, max_period=10000.0):
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C
B, T, H, W, dim = x.shape
pt, ph, pw = self.patch_size
x = x.view(
B,
T // pt, pt,
H // ph, ph,
W // pw, pw,
dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x)
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
return self.out_layer(self.activation(x))
class SelfAttention(nn.Module):
def __init__(self, num_channels, head_dim, operation_settings=None):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.head_dim = head_dim
operations = operation_settings.get("operations")
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 2
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)
def _forward(self, x, freqs, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = []
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def forward(self, x, freqs, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
class CrossAttention(SelfAttention):
def get_qkv(self, x, context):
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
return q, k, v
def forward(self, x, context, transformer_options={}):
q, k, v = self.get_qkv(x, context)
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.GELU()
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 4
def _forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
chunks = torch.chunk(x, self.num_chunks, dim=1)
output_chunks = []
for chunk in chunks:
output_chunks.append(self._forward(chunk))
return torch.cat(output_chunks, dim=1)
def forward(self, x):
if x.shape[1] > 8192:
return self._forward_chunked(x)
else:
return self._forward(x)
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, visual_embed, time_embed):
B, T, H, W, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
scale = scale[:, None, None, None, :]
shift = shift[:, None, None, None, :]
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
x = self.out_layer(visual_embed)
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
x = x.view(
B, T, H, W,
out_dim,
self.patch_size[0], self.patch_size[1], self.patch_size[2]
)
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, x, time_embed, freqs, transformer_options={}):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = get_shift_scale_gate(self_attn_params)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, freqs, transformer_options=transformer_options)
x = apply_gate_sum(x, out, gate)
shift, scale, gate = get_shift_scale_gate(ff_params)
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = apply_gate_sum(x, out, gate)
return x
class TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
# self attention
shift, scale, gate = get_shift_scale_gate(self_attn_params)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# cross attention
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# feed forward
shift, scale, gate = get_shift_scale_gate(ff_params)
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
return visual_embed
class Kandinsky5(nn.Module):
def __init__(
self,
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
dtype=None, device=None, operations=None, **kwargs
):
super().__init__()
head_dim = sum(axes_dims)
self.rope_scale_factor = rope_scale_factor
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_embed_dim = visual_embed_dim
self.dtype = dtype
self.device = device
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
self.text_transformer_blocks = nn.ModuleList(
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
)
self.visual_transformer_blocks = nn.ModuleList(
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
)
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
steps = seq_len if steps is None else steps
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
return freqs
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
else:
rope_scale_factor = self.rope_scale_factor
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
if h * w >= 14080:
rope_scale_factor = (1.0, 3.16, 3.16)
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
return freqs
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
visual_embed = self.visual_embeddings(x)
visual_shape = visual_embed.shape[:-1]
visual_embed = visual_embed.flatten(1, -2)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
else:
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
original_dims = x.ndim
if original_dims == 4:
x = x.unsqueeze(2)
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if time_dim_replace is not None:
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
if original_dims == 4:
out = out.squeeze(2)
return out
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@ -0,0 +1,160 @@
import torch
from torch import nn
from .model import JointTransformerBlock
class ZImageControlTransformerBlock(JointTransformerBlock):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0,
operation_settings=None,
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
self.block_id = block_id
if block_id == 0:
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c_skip, c
class ZImage_Control(torch.nn.Module):
def __init__(
self,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
multiple_of: int = 256,
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None,
device=None,
operations=None,
**kwargs
):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.broken = broken
self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2
self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
i,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=i,
operation_settings=operation_settings,
)
for i in range(self.n_control_layers)
]
)
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
if self.refiner_control:
self.control_noise_refiner = nn.ModuleList(
[
ZImageControlTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=layer_id,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2
f_patch_size = 1
pH = pW = patch_size
B, C, H, W = control_context.shape
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
if not self.refiner_control:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -22,6 +22,10 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
def clamp_fp16(x):
if x.dtype == torch.float16:
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@ -169,7 +173,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
self.feed_forward(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
)
))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
self.attention(
clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
))
)
x = x + self.ffn_norm2(
self.feed_forward(
@ -373,6 +377,7 @@ class NextDiT(nn.Module):
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@ -443,6 +448,31 @@ class NextDiT(nn.Module):
),
)
self.clip_text_pooled_proj = None
if clip_text_dim is not None:
self.clip_text_dim = clip_text_dim
self.clip_text_pooled_proj = nn.Sequential(
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
clip_text_dim,
clip_text_dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.time_text_embed = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024) + clip_text_dim,
min(dim, 1024),
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@ -461,7 +491,8 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
@ -506,6 +537,7 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -542,13 +574,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
for layer in self.noise_refiner:
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
@ -564,7 +604,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@ -581,16 +621,36 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
img = self.final_layer(img, adaln_input)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -x
return -img

View File

@ -30,6 +30,13 @@ except ImportError as e:
raise e
exit(-1)
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
@ -517,6 +524,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -541,6 +549,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
@ -560,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head)
return out
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads
if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape
try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@ -647,6 +744,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():

View File

@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
if len(xl) > 1:
return torch.cat(xl, dim)
else:
return xl[0]
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class CarriedConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if isinstance(op, CarriedConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
out = op(x)
return out
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@ -89,29 +126,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
results = []
if conv_carry_in is None:
first = x[:, :, :1, :, :]
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
x = x[:, :, 1:, :, :]
if x.shape[2] > 0:
results.append(interpolate_up(x, scale_factor))
x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = self.conv(x)
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@ -127,17 +159,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
if x.ndim == 4:
if isinstance(self.conv, CarriedConv3d):
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb=None):
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
h = self.swish(h)
h = self.conv1(h)
h = [ self.swish(h) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
h = self.dropout(h)
h = self.conv2(h)
h = [ self.dropout(h) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@ -356,7 +394,8 @@ class Model(nn.Module):
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
@ -510,16 +549,22 @@ class Encoder(nn.Module):
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.carried = False
if conv3d:
conv_op = VideoConv3d
if not attn_resolutions:
conv_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -532,6 +577,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@ -558,10 +604,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
else:
self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
if time_compress is not None:
self.time_compress = time_compress
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@ -587,15 +638,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions-1:
h = self.down[i_level].downsample(h)
if self.carried:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.time_compress:
tc = self.time_compress
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
x = xl
else:
x = [x]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
# downsampling
x1 = [ x1 ]
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
if len(self.down[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.down[i_level].attn[i_block](h1)
if i_level != self.num_resolutions-1:
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
out.append(h1)
conv_carry_in = conv_carry_out
h = torch_cat_if_needed(out, dim=2)
del out
# middle
h = self.mid.block_1(h, temb)
@ -604,15 +682,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = [ nonlinearity(h) ]
h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@ -626,12 +704,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.carried = False
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
if not attn_resolutions and resnet_op == ResnetBlock:
conv_op = CarriedConv3d
conv_out_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -706,29 +790,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h = self.conv_in(z)
h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
if self.carried:
h = torch.split(h, 2, dim=2)
else:
h = [ h ]
out = []
conv_carry_in = None
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
for i, h1 in enumerate(h):
conv_carry_out = []
if i == len(h) - 1:
conv_carry_out = None
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
if len(self.up[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
if i_level != 0:
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h1)
h1 = [ nonlinearity(h1) ]
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
if self.tanh_out:
h1 = torch.tanh(h1)
out.append(h1)
conv_carry_in = conv_carry_out
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h
out = torch_cat_if_needed(out, dim=2)
return out

View File

@ -45,7 +45,7 @@ class LitEma(nn.Module):
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
@ -54,7 +54,7 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def store(self, parameters):
"""

View File

@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations
)
def forward(self, timestep, hidden_states):
self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb
@ -218,9 +228,24 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@ -229,14 +254,19 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
@ -251,15 +281,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -300,10 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
use_additional_t_cond=False,
dtype=None,
device=None,
operations=None,
@ -314,12 +345,14 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype,
device=device,
operations=operations
@ -341,6 +374,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
if self.default_ref_method == "index_timestep_zero":
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@ -350,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
def _forward(
self,
@ -378,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
additional_t_cond=None,
transformer_options={},
control=None,
**kwargs
@ -391,16 +433,24 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
@ -415,6 +465,10 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -426,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
@ -446,7 +493,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
@ -458,6 +505,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options,
)
@ -474,9 +522,12 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None:
hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -71,7 +71,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config):
if not "target" in config:
if "target" not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":

View File

@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
input_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
output_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
CausalConv3d(out_dim, output_channels, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):

View File

@ -320,8 +320,16 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
return key_map

View File

@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.model_management
import comfy.patcher_extension
@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@ -1121,6 +1110,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
return out
class WAN21(BaseModel):
@ -1642,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out
class Kandinsky5(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
device = kwargs["device"]
image = torch.zeros_like(noise)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
time_dim_replace = kwargs.get("time_dim_replace", None)
if time_dim_replace is not None:
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
return out
class Kandinsky5Image(Kandinsky5):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
return None

View File

@ -6,20 +6,6 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@ -194,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@ -208,12 +196,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
@ -223,6 +211,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
@ -245,6 +234,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
else:
dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@ -268,8 +259,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@ -429,6 +429,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@ -615,6 +619,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
if model_dim == 2560: # lite image
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
@ -759,22 +786,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config:
model_config.quant_config = quant_config
logging.info("Detected mixed precision quantization")
return model_config

View File

@ -26,6 +26,7 @@ import importlib
import platform
import weakref
import gc
import os
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -333,13 +334,15 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
@ -1016,8 +1019,8 @@ NUM_STREAMS = 0
if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia
if is_nvidia():
# Enable by default on Nvidia and AMD
if is_nvidia() or is_amd():
NUM_STREAMS = 2
if args.disable_async_offload:
@ -1123,6 +1126,16 @@ if not args.disable_pinned_memory:
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
@ -1155,6 +1168,9 @@ def pin_memory(tensor):
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
else:
logging.warning("Pin error.")
discard_cuda_async_error()
return False
@ -1183,6 +1199,9 @@ def unpin_memory(tensor):
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
else:
logging.warning("Unpin error.")
discard_cuda_async_error()
return False
@ -1492,6 +1511,20 @@ def extended_fp16_support():
return True
LORA_COMPUTE_DTYPES = {}
def lora_compute_dtype(device):
dtype = LORA_COMPUTE_DTYPES.get(device, None)
if dtype is not None:
return dtype
if should_use_fp16(device):
dtype = torch.float16
else:
dtype = torch.float32
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@ -1509,6 +1542,10 @@ def soft_empty_cache(force=False):
def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading

View File

@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -126,36 +127,23 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight, inplace=False)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
if model_dtype is None:
model_dtype = weight.dtype
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@ -466,6 +454,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
@ -630,10 +621,11 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@ -677,12 +669,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
def check_module_offload_mem(key):
if key in self.patches:
return low_vram_patch_estimate_vram(self.model, key)
model_dtype = getattr(self.model, "manual_cast_dtype", None)
weight, _, _ = get_key_weight(self.model, key)
if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
@ -699,12 +697,12 @@ class ModelPatcher:
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for x in loading:
for i, x in enumerate(loading):
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@ -777,6 +775,8 @@ class ModelPatcher:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@ -876,14 +876,18 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
if len(unload_list) > 0:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -929,12 +933,14 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
if cast_weight:
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params:

View File

@ -22,7 +22,7 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
import json
def run_every_op():
if torch.compiler.is_compiling():
@ -93,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
@ -111,22 +104,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
comfy.model_management.sync_stream(device, offload_stream)
bias_a = bias
weight_a = weight
if s.bias is not None:
for f in s.bias_function:
bias = f(bias)
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
return weight, bias, offload_stream
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
@ -135,13 +130,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
os, weight_a, bias_a = offload_stream
if os is None:
return
if weight_a is not None:
device = weight_a.device
else:
if bias is None:
if bias_a is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
@ -417,22 +415,12 @@ def fp8_linear(self, input):
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@ -453,7 +441,7 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
if not self.training:
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
try:
out = fp8_linear(self, input)
if out is not None:
@ -466,59 +454,6 @@ class fp8_ops(manual_cast):
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
@ -545,9 +480,9 @@ if CUBLAS_IS_AVAILABLE:
from .quant_ops import QuantizedTensor, QUANT_ALGOS
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_layer_quant_config = layer_quant_config
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
@ -562,15 +497,14 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -590,36 +524,59 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format]
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
if scale is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@ -628,6 +585,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if key in missing_keys:
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -643,9 +610,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
@ -656,7 +622,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@ -665,17 +631,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and

View File

@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor):
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
@ -397,17 +398,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if scale is not None:
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)

View File

@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@ -720,7 +720,7 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
@ -984,9 +984,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@ -1013,6 +1010,24 @@ class CFGGuider:
else:
latent_shapes = [latent_image.shape]
if denoise_mask is not None:
if denoise_mask.is_nested:
denoise_masks = denoise_mask.unbind()
denoise_masks = denoise_masks[:len(latent_shapes)]
else:
denoise_masks = [denoise_mask]
for i in range(len(denoise_masks), len(latent_shapes)):
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

View File

@ -53,6 +53,10 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.model_patcher
import comfy.lora
@ -97,7 +101,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
return
params = target.params.copy()
@ -125,9 +129,32 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if len(state_dict) > 0:
if isinstance(state_dict, list):
for c in state_dict:
m, u = self.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
else:
m, u = self.load_sd(state_dict, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@ -192,6 +219,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@ -239,6 +267,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@ -294,6 +323,7 @@ class VAE:
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
@ -408,6 +438,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
@ -468,7 +499,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@ -480,8 +511,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@ -517,11 +550,15 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
@ -551,6 +588,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
@ -659,17 +697,28 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
if self.crop_input:
downscale_ratio = self.spacial_compression_encode()
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
if pixels.shape[-1] > self.output_channels:
pixels = pixels[..., :self.output_channels]
elif pixels.shape[-1] < self.output_channels:
if self.pad_channel_value is not None:
if isinstance(self.pad_channel_value, str):
mode = self.pad_channel_value
value = None
else:
mode = "constant"
value = self.pad_channel_value
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -740,6 +789,8 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -956,16 +1007,18 @@ class CLIPType(Enum):
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
NEWBIE = 24
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@ -987,6 +1040,8 @@ class TEModel(Enum):
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
JINA_CLIP_2 = 18
def detect_te_model(sd):
@ -996,6 +1051,8 @@ def detect_te_model(sd):
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
@ -1020,9 +1077,12 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@ -1078,7 +1138,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@ -1102,7 +1162,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@ -1131,7 +1191,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@ -1150,13 +1210,19 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@ -1199,6 +1265,23 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
elif clip_type == CLIPType.KANDINSKY5:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]
clip_data_jina = clip_data[1]
else:
clip_data_gemma = clip_data[1]
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -1211,19 +1294,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0
for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
return clip
def load_gligen(ckpt_path):
@ -1282,6 +1356,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@ -1290,18 +1368,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
if model_config.quant_config is not None:
weight_dtype = None
model_config.custom_operations = model_options.get("custom_operations", None)
if custom_operations is not None:
model_config.custom_operations = custom_operations
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@ -1319,22 +1401,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
scaled_fp8_list = []
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
if k.endswith(".scaled_fp8"):
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
out_sd = {}
for k in sd:
skip = False
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
@ -1381,6 +1474,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
if len(temp_sd) > 0:
sd = temp_sd
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
@ -1411,7 +1507,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
if model_config.quant_config is not None:
weight_dtype = None
if dtype is None:
@ -1419,12 +1515,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
if model_config.layer_quant_config is not None:
if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if custom_operations is not None:
model_config.custom_operations = custom_operations
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@ -1463,6 +1562,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)

View File

@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)
quant_config = model_options.get("quantization_metadata", None)
if operations is None:
layer_quant_config = None
if quantization_metadata is not None:
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
if quant_config is not None:
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
logging.info("Using MixedPrecisionOps for text encoder")
else:
# Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
if scaled_fp8 is not None:
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
@ -147,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@ -163,6 +152,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
@ -175,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@ -258,7 +249,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
if self.execution_device is None:
device = self.transformer.get_input_embeddings().weight.device
else:
device = self.execution_device
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
@ -471,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@ -518,6 +513,8 @@ class SDTokenizer:
self.embedding_size = embedding_size
self.embedding_key = embedding_key
self.disable_weights = disable_weights
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
@ -552,7 +549,7 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
if kwargs.get("disable_weights", False):
if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)

View File

@ -21,12 +21,14 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@ -540,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.2
memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
@ -964,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@ -1025,7 +1027,15 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 1.7
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -1286,7 +1296,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038
memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@ -1329,7 +1339,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1394,7 +1404,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1472,7 +1482,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
class Kandinsky5(supported_models_base.BASE):
unet_config = {
"image_model": "kandinsky5",
}
sampling_settings = {
"shift": 10.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
class Kandinsky5Image(Kandinsky5):
unet_config = {
"image_model": "kandinsky5",
"model_dim": 2560,
"visual_embed_dim": 64,
}
sampling_settings = {
"shift": 3.0,
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@ -17,6 +17,7 @@
"""
import torch
import logging
from . import model_base
from . import utils
from . import latent_formats
@ -49,8 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
quant_config = None # quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
@ -118,3 +118,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
def __getattr__(self, name):
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
return None

View File

@ -154,7 +154,8 @@ class TAEHV(nn.Module):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5xxl_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
def te(dtype_t5=None, t5_quantization_metadata=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()

View File

@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
else:
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)

View File

@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
import torch
import os
import numbers
import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["llama_quantization_metadata"] = quant
return out
@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_

View File

@ -0,0 +1,219 @@
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
from dataclasses import dataclass
import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
vocab_size: int = 250002
type_vocab_size: int = 1
hidden_size: int = 1024
num_hidden_layers: int = 24
num_attention_heads: int = 16
rotary_emb_base: float = 20000.0
intermediate_size: int = 4096
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-05
bos_token_id: int = 0
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
def forward(self, input_ids=None, embeddings=None):
if input_ids is not None and embeddings is None:
embeddings = self.word_embeddings(input_ids)
if embeddings is not None:
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
def forward(self, q, k):
batch, seqlen, heads, head_dim = q.shape
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
def rotate_half(x):
size = x.shape[-1] // 2
x1, x2 = x[..., :size], x[..., size:]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = embed_dim // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None, optimized_attention=None):
qkv = self.Wqkv(x)
batch_size, seq_len, _ = qkv.shape
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q, k = self.rotary_emb(q, k)
# NHD -> HND
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.activation = F.gelu
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states, mask=None, optimized_attention=None):
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
mlp_out = self.mlp(hidden_states)
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
sequence_output = self.encoder(x, attention_mask=mask)
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
pooled_output = None
if attention_mask is None:
pooled_output = sequence_output.mean(dim=1)
else:
attention_mask = attention_mask.to(sequence_output.dtype)
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.config = XLMRobertaConfig(**config_dict)
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
def get_input_embeddings(self):
return self.model.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.model.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)

View File

@ -0,0 +1,68 @@
from comfy import sd1_clip
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from .llama import Qwen25_7BVLI
class Kandinsky5Tokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Kandinsky5TEModel(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
return cond, l_pooled, extra
def set_clip_options(self, options):
super().set_clip_options(options)
self.clip_l.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
self.clip_l.reset_clip_options()
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return super().load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class Kandinsky5TEModel_(Kandinsky5TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Kandinsky5TEModel_

View File

@ -3,13 +3,11 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
import comfy.model_management
from . import qwen_vl
@dataclass
@ -100,6 +98,28 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
@dataclass
class Ovis25_2BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@ -155,7 +175,7 @@ class Gemma3_4B_Config:
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
@ -164,8 +184,8 @@ class Gemma3_4B_Config:
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
class RMSNorm(nn.Module):
@ -348,7 +368,7 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
if config.sliding_attention is not None:
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
@ -365,7 +385,12 @@ class TransformerBlockGemma2(nn.Module):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
else:
attention_mask = sliding_mask
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
@ -542,6 +567,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Ovis25_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Ovis25_2BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
@ -40,7 +45,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
@ -48,9 +53,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)

View File

@ -0,0 +1,62 @@
import torch
import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
raise NotImplementedError
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
self.gemma.set_clip_options(options)
self.jina.set_clip_options(options)
def reset_clip_options(self):
self.gemma.reset_clip_options()
self.jina.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_gemma = token_weight_pairs["gemma"]
token_weight_pairs_jina = token_weight_pairs["jina"]
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
return gemma_out, jina_pooled, gemma_extra
def load_sd(self, sd):
if "model.layers.0.self_attn.q_norm.weight" in sd:
return self.gemma.load_sd(sd)
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_

View File

@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_quantization_metadata=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -0,0 +1,66 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
import numbers
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
class OvisTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Ovis25_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
class OvisTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen3_2b"][0]
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 4004 and count_im_start < 1:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 1):
if tok_pairs[template_end + 1][0] == 25:
template_end += 1
out = out[:, template_end:]
return out, pooled, {}
def te(dtype_llama=None, llama_quantization_metadata=None):
class OvisTEModel_(OvisTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_

View File

@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -6,14 +6,15 @@ import torch
import os
import comfy.model_management
import logging
import comfy.utils
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["t5_quantization_metadata"] = quant
return out
@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
def te(dtype_t5=None, t5_quantization_metadata=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:

View File

@ -29,6 +29,7 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -52,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
@ -802,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
ATTR_UNSET={}
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], value)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
else:
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
@ -1194,3 +1200,70 @@ def unpack_latents(combined_latent, latent_shapes):
else:
output_tensors = combined_latent
return output_tensors
def detect_layer_quantization(state_dict, prefix):
for k in state_dict:
if k.startswith(prefix) and k.endswith(".comfy_quant"):
logging.info("Found quantization metadata version 1")
return {"mixed_ops": True}
return None
def convert_old_quants(state_dict, model_prefix="", metadata={}):
if metadata is None:
metadata = {}
quant_metadata = None
if "_quantization_metadata" not in metadata:
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
continue
out_sd[k_out] = w
state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata

View File

@ -5,19 +5,20 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from typing import Any
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
}
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
@ -41,7 +42,7 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
@ -59,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
def get_server_features() -> dict[str, Any]:
"""
Get the server's feature flags.

View File

@ -1,4 +1,4 @@
from typing import Type, List, NamedTuple
from typing import NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
api_class: type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
registered_versions: list[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
def register_versions(versions: list[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
def get_all_versions() -> list[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args, get_type_hints
from typing import Optional, get_origin, get_args, get_type_hints
class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
"""
Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
cls, async_class: type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
attr: type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
"""
Creates a sync version of an async class

View File

@ -1,4 +1,4 @@
from typing import Type, TypeVar
from typing import TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
def inject_instance(cls: type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
def get_instance(cls: type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""

View File

@ -1,16 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@ -80,7 +79,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
@ -113,7 +112,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui

View File

@ -1,5 +1,5 @@
import torch
from typing import TypedDict, List, Optional
from typing import TypedDict, Optional
ImageInput = torch.Tensor
"""
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]
batch_index: Optional[list[int]]

View File

@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""

View File

@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
@ -337,7 +337,7 @@ class VideoFromComponents(VideoInput):
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
extra_kwargs = {}
if format != VideoContainer.AUTO:
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams

View File

@ -4,7 +4,8 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import asdict, dataclass
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@ -25,11 +26,9 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
from ._util import MESH, VOXEL, SVG as _SVG
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
class FolderType(str, Enum):
input = "input"
@ -76,16 +75,6 @@ class NumberDisplay(str, Enum):
slider = "slider"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class _ComfyType(ABC):
Type = Any
io_type: str = None
@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
new_cls.io_type = _StringIOType(io_type)
new_cls.io_type = io_type
if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -150,6 +138,9 @@ class _IO_V3:
def __init__(self):
pass
def validate(self):
pass
@property
def io_type(self):
return self.Parent.io_type
@ -162,7 +153,7 @@ class Input(_IO_V3):
'''
Base class for a V3 Input.
'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__()
self.id = id
self.display_name = display_name
@ -170,6 +161,7 @@ class Input(_IO_V3):
self.tooltip = tooltip
self.lazy = lazy
self.extra_dict = extra_dict if extra_dict is not None else {}
self.rawLink = raw_link
def as_dict(self):
return prune_dict({
@ -177,10 +169,14 @@ class Input(_IO_V3):
"optional": self.optional,
"tooltip": self.tooltip,
"lazy": self.lazy,
"rawLink": self.rawLink,
}) | prune_dict(self.extra_dict)
def get_io_type(self):
return _StringIOType(self.io_type)
return self.io_type
def get_all(self) -> list[Input]:
return [self]
class WidgetInput(Input):
'''
@ -188,8 +184,8 @@ class WidgetInput(Input):
'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None,
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.default = default
self.socketless = socketless
self.widget_type = widget_type
@ -211,13 +207,14 @@ class Output(_IO_V3):
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
self.id = id
self.display_name = display_name
self.display_name = display_name if display_name else id
self.tooltip = tooltip
self.is_output_list = is_output_list
def as_dict(self):
display_name = self.display_name if self.display_name else self.id
return prune_dict({
"display_name": self.display_name,
"display_name": display_name,
"tooltip": self.tooltip,
"is_output_list": self.is_output_list,
})
@ -245,8 +242,8 @@ class Boolean(ComfyTypeIO):
'''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.label_on = label_on
self.label_off = label_off
self.default: bool
@ -265,8 +262,8 @@ class Int(ComfyTypeIO):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min
self.max = max
self.step = step
@ -291,8 +288,8 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min
self.max = max
self.step = step
@ -317,8 +314,8 @@ class String(ComfyTypeIO):
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.multiline = multiline
self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts
@ -351,12 +348,14 @@ class Combo(ComfyTypeIO):
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
extra_dict=None,
raw_link: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
@ -380,10 +379,6 @@ class Combo(ComfyTypeIO):
super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
@ -392,8 +387,8 @@ class MultiCombo(ComfyTypeI):
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
socketless: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
self.multiselect = True
self.placeholder = placeholder
self.chip = chip
@ -426,9 +421,9 @@ class Webcam(ComfyTypeIO):
Type = str
def __init__(
self, id: str, display_name: str=None, optional=False,
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
@comfytype(io_type="MASK")
@ -561,6 +556,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
time_dim_replace: NotRequired[torch.Tensor]
'''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
@ -647,7 +644,7 @@ class Video(ComfyTypeIO):
@comfytype(io_type="SVG")
class SVG(ComfyTypeIO):
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
Type = _SVG
@comfytype(io_type="LORA_MODEL")
class LoraModel(ComfyTypeIO):
@ -765,6 +762,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -772,7 +776,7 @@ class MultiType:
'''
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
'''
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
# if id is an Input, then use that Input with overridden values
self.input_override = None
if isinstance(id, Input):
@ -785,7 +789,7 @@ class MultiType:
# if is a widget input, make sure widget_type is set appropriately
if isinstance(self.input_override, WidgetInput):
self.input_override.widget_type = self.input_override.get_io_type()
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self._io_types = types
@property
@ -814,115 +818,326 @@ class MultiType:
else:
return super().as_dict()
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
self.template_id = template_id
# account for syntactic sugar
if not isinstance(allowed_types, Iterable):
allowed_types = [allowed_types]
for t in allowed_types:
if not isinstance(t, type):
if not isinstance(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
else:
if not issubclass(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
self.allowed_types = allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
}
class Input(Input):
def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
if not id and not display_name:
display_name = "MATCHTYPE"
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
@abstractmethod
def get_dynamic(self) -> list[Input]:
...
pass
class DynamicOutput(Output, ABC):
'''
Abstract class for dynamic output registration.
'''
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
pass
@abstractmethod
def get_dynamic(self) -> list[Output]:
...
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
if prefix_list is None:
prefix_list = []
if id is not None:
prefix_list = prefix_list + [id]
return prefix_list
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
assert not (prefix_list is None and id is None)
if prefix_list is None:
return id
elif id is not None:
prefix_list = prefix_list + [id]
return ".".join(prefix_list)
@comfytype(io_type="COMFY_AUTOGROW_V3")
class AutogrowDynamic(ComfyTypeI):
Type = list[Any]
class Input(DynamicInput):
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template_input = template_input
if min is not None:
assert(min >= 1)
if max is not None:
assert(max >= 1)
class Autogrow(ComfyTypeI):
Type = dict[str, Any]
_MaxNames = 100 # NOTE: max 100 names for sanity
class _AutogrowTemplate:
def __init__(self, input: Input):
# dynamic inputs are not allowed as the template input
assert(not isinstance(input, DynamicInput))
self.input = copy.copy(input)
if isinstance(self.input, WidgetInput):
self.input.force_input = True
self.names: list[str] = []
self.cached_inputs = {}
def _create_input(self, input: Input, name: str):
new_input = copy.copy(self.input)
new_input.id = name
return new_input
def _create_cached_inputs(self):
for name in self.names:
self.cached_inputs[name] = self._create_input(self.input, name)
def get_all(self) -> list[Input]:
return list(self.cached_inputs.values())
def as_dict(self):
return prune_dict({
"input": create_input_dict_v1([self.input]),
})
def validate(self):
self.input.validate()
class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input)
self.prefix = prefix
assert(min >= 0)
assert(max >= 1)
assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
self._create_cached_inputs()
def get_dynamic(self) -> list[Input]:
curr_count = 1
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
def as_dict(self):
return super().as_dict() | prune_dict({
"prefix": self.prefix,
"min": self.min,
"max": self.max,
})
class TemplateNames(_AutogrowTemplate):
def __init__(self, input: Input, names: list[str], min: int=1):
super().__init__(input)
self.names = names[:Autogrow._MaxNames]
assert(min >= 0)
self.min = min
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"names": self.names,
"min": self.min,
})
@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
class ComboDynamic(ComfyTypeI):
class Input(DynamicInput):
def __init__(self, id: str):
pass
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
self.template_id = template_id
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
def get_all(self) -> list[Input]:
return [self] + self.template.get_all()
def validate(self):
self.template.validate()
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
# need to figure out names based on template type
is_names = ("names" in value[1]["template"])
is_prefix = ("prefix" in value[1]["template"])
input = value[1]["template"]["input"]
if is_names:
min = value[1]["template"]["min"]
names = value[1]["template"]["names"]
max = len(names)
elif is_prefix:
prefix = value[1]["template"]["prefix"]
min = value[1]["template"]["min"]
max = value[1]["template"]["max"]
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_input = list(dict_input.values())[0]
new_dict = {}
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
type_dict[name] = template_input
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI):
Type = dict[str, Any]
class Option:
def __init__(self, key: str, inputs: list[Input]):
self.key = key
self.inputs = inputs
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": "".join(t.io_type for t in self.allowed_types),
"key": self.key,
"inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
def __init__(self, id: str, template: MatchType.Template,
def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
self.options = options
def get_dynamic(self) -> list[Input]:
return [self]
def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
"options": [o.as_dict() for o in self.options],
})
class Output(DynamicOutput):
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def validate(self):
# make sure all nested inputs are validated
for option in self.options:
for input in option.inputs:
input.validate()
def get_dynamic(self) -> list[Output]:
return [self]
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
finalized_id = finalize_prefix(curr_prefix)
if finalized_id in live_inputs:
key = live_inputs[finalized_id]
selected_option = None
# get options from dict
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
for option in options:
if option["key"] == key:
selected_option = option
break
if selected_option is not None:
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
# add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI):
Type = dict[str, Any]
class Input(DynamicInput):
def __init__(self, slot: Input, inputs: list[Input],
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
assert(not isinstance(slot, DynamicInput))
self.slot = copy.copy(slot)
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
optional = True
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
self.inputs = inputs
self.force_input = None
# force widget inputs to have no widgets, otherwise this would be awkward
if isinstance(self.slot, WidgetInput):
self.force_input = True
self.slot.force_input = True
def get_all(self) -> list[Input]:
return [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
"slotType": str(self.slot.get_io_type()),
"inputs": create_input_dict_v1(self.inputs),
"forceInput": self.force_input,
})
def validate(self):
self.slot.validate()
for input in self.inputs:
input.validate()
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
finalized_id = finalize_prefix(curr_prefix)
if finalized_id in live_inputs:
inputs = value[1]["inputs"]
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
# add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
return DYNAMIC_INPUT_LOOKUP[io_type]
def setup_dynamic_input_funcs():
# Autogrow.Input
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
# DynamicCombo.Input
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
class V3Data(TypedDict):
hidden_inputs: dict[str, Any]
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@ -958,6 +1173,10 @@ class HiddenHolder:
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
)
@classmethod
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
class Hidden(str, Enum):
'''
Enumerator for requesting hidden variables in nodes.
@ -984,6 +1203,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@ -1019,9 +1239,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
inputs: list[Input]=None
outputs: list[Output]=None
hidden: list[Hidden]=None
inputs: list[Input] = field(default_factory=list)
outputs: list[Output] = field(default_factory=list)
hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
@ -1061,60 +1281,57 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
nested_inputs: list[Input] = []
for input in self.inputs:
if not isinstance(input, DynamicInput):
nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs]
output_ids = [o.id for o in self.outputs]
input_set = set(input_ids)
output_set = set(output_ids)
issues = []
issues: list[str] = []
# verify ids are unique per list
if len(input_set) != len(input_ids):
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
if len(output_set) != len(output_ids):
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
# verify ids are unique between lists
intersection = input_set & output_set
if len(intersection) > 0:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
# validate inputs and outputs
for input in self.inputs:
input.validate()
for output in self.outputs:
output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
# ensure inputs, outputs, and hidden are lists
if self.inputs is None:
self.inputs = []
if self.outputs is None:
self.outputs = []
if self.hidden is None:
self.hidden = []
# if is an api_node, will need key-related hidden
if self.is_api_node:
if self.hidden is None:
self.hidden = []
if Hidden.auth_token_comfy_org not in self.hidden:
self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org)
# if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node:
if self.hidden is None:
self.hidden = []
if Hidden.prompt not in self.hidden:
self.hidden.append(Hidden.prompt)
if Hidden.extra_pnginfo not in self.hidden:
self.hidden.append(Hidden.extra_pnginfo)
# give outputs without ids default ids
if self.outputs is not None:
for i, output in enumerate(self.outputs):
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
for i, output in enumerate(self.outputs):
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls) -> NodeInfoV1:
# get V1 inputs
input = {
"required": {}
}
if self.inputs:
for i in self.inputs:
if isinstance(i, DynamicInput):
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
input = create_input_dict_v1(self.inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1123,12 +1340,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
output_matchtypes = []
any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
# special handling for MatchType
if isinstance(o, MatchType.Output):
output_matchtypes.append(o.template.template_id)
any_matchtypes = True
else:
output_matchtypes.append(None)
# clear out lists that are all None
if not any_matchtypes:
output_matchtypes = None
info = NodeInfoV1(
input=input,
@ -1137,6 +1366,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@ -1181,17 +1411,84 @@ class Schema:
)
return info
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
out_dict = {
"required": {},
"optional": {},
"dynamic_paths": {},
}
d = d.copy()
# ignore hidden for parsing
hidden = d.pop("hidden", None)
parse_class_inputs(out_dict, live_inputs, d)
if hidden is not None and include_hidden:
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
v3_data["dynamic_paths"] = dynamic_paths
return out_dict, hidden, v3_data
def add_to_dict_v1(i: Input, input: dict):
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
for input_type, inner_d in curr_dict.items():
for id, value in inner_d.items():
io_type = value[0]
if io_type in DYNAMIC_INPUT_LOOKUP:
# dynamic inputs need to be handled with lookup functions
dynamic_input_func = get_dynamic_input_func(io_type)
new_prefix = handle_prefix(curr_prefix, id)
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
else:
# non-dynamic inputs get directly transferred
finalized_id = finalize_prefix(curr_prefix, id)
out_dict[input_type][finalized_id] = value
if curr_prefix:
out_dict["dynamic_paths"][finalized_id] = finalized_id
def create_input_dict_v1(inputs: list[Input]) -> dict:
input = {
"required": {}
}
for i in inputs:
add_to_dict_v1(i, input)
return input
def add_to_dict_v1(i: Input, d: dict):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
if paths is None:
return values
values = values.copy()
result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
for key, path in paths.items():
parts = path.split(".")
current = result
for i, p in enumerate(parts):
is_last = (i == len(parts) - 1)
if is_last:
value = values.pop(key, None)
if create_tuple:
value = (value, key)
current[p] = value
else:
current = current.setdefault(p, {})
values.update(result)
return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@ -1201,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
SCHEMA = None
# filled in during execution
resources: Resources = None
hidden: HiddenHolder = None
@classmethod
@ -1248,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
return [name for name in kwargs if kwargs[name] is None]
def __init__(self):
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS()
@classmethod
@ -1311,12 +1606,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone
@final
@ -1433,15 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
return input, schema
return input
return info.input
@final
@classmethod
@ -1513,7 +1803,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool:
def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@ -1560,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
return self.args if len(self.args) > 0 else None
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
args = ()
ui = None
expand = None
@ -1573,7 +1863,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(args=args, ui=ui, expand=expand)
return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
@ -1628,6 +1918,7 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
"LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@ -1651,6 +1942,11 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
"Tracks",
# Dynamic Types
"MatchType",
"DynamicCombo",
"Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@ -1661,4 +1957,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
]

View File

@ -0,0 +1 @@
from ._io import * # noqa: F403

View File

@ -1,72 +0,0 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@ -3,8 +3,8 @@ from __future__ import annotations
import json
import os
import random
import uuid
from io import BytesIO
from typing import Type
import av
import numpy as np
@ -21,7 +21,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
@ -82,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -95,7 +95,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -120,7 +120,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
@ -136,7 +136,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -154,7 +154,7 @@ class ImageSaveHelper:
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
@ -168,7 +168,7 @@ class ImageSaveHelper:
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -190,7 +190,7 @@ class ImageSaveHelper:
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
@ -208,7 +208,7 @@ class ImageSaveHelper:
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -237,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -266,7 +266,7 @@ class AudioSaveHelper:
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
@ -318,9 +318,10 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@ -332,7 +333,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@ -341,12 +342,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
layout=layout,
)
frame.sample_rate = sample_rate
frame.pts = 0
@ -370,7 +371,7 @@ class AudioSaveHelper:
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
@ -386,7 +387,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -410,7 +411,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
@ -436,9 +437,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
self.bg_image_path = None
bg_image = kwargs.get("bg_image", None)
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = PILImage.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
self.bg_image_path = f"temp/{filename}"
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput):

View File

@ -0,0 +1 @@
from ._ui import * # noqa: F403

View File

@ -1,5 +1,6 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
from .image_types import SVG
__all__ = [
# Utility Types
@ -8,4 +9,5 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
"SVG",
]

View File

@ -0,0 +1,18 @@
from io import BytesIO
class SVG:
"""Stores SVG representations via a list of BytesIO objects."""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@ -42,4 +42,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"IO",
"ui",
"UI",
]

View File

@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [
supported_versions: list[type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,

View File

@ -0,0 +1,144 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(False)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(False)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(False)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}

View File

@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):
@ -141,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
uploadImagesToStorage: bool = Field(True)
class GeminiGenerateContentRequest(BaseModel):

View File

@ -46,21 +46,68 @@ class TaskStatusVideoResult(BaseModel):
url: str | None = Field(None, description="URL for generated video")
class TaskStatusVideoResults(BaseModel):
class TaskStatusImageResult(BaseModel):
index: int = Field(..., description="Image Number0-9")
url: str = Field(..., description="URL for generated image")
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusVideoResponseData(BaseModel):
class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: TaskStatusVideoResults | None = Field(None)
task_result: TaskStatusResults | None = Field(None)
class TaskStatusVideoResponse(BaseModel):
class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: TaskStatusVideoResponseData | None = Field(None)
data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
image: str = Field(...)
class OmniProImageRequest(BaseModel):
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class MotionControlRequest(BaseModel):
prompt: str = Field(...)
image_url: str = Field(...)
video_url: str = Field(...)
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")

View File

@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = None
text_tokens: int | None = None
class Usage(BaseModel):
input_tokens: int | None = None
input_tokens_details: InputTokensDetails | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = None
usage: Usage | None = None
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")

View File

@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel):
root: Union[

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):

View File

@ -1,10 +1,8 @@
from inspect import cleandoc
import torch
from pydantic import BaseModel
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
)
def convert_mask_to_image(mask: torch.Tensor):
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
"prompt",
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False,
raw: bool = False,
seed: int = 0,
image_prompt: torch.Tensor | None = None,
image_prompt: Input.Image | None = None,
image_prompt_strength: float = 0.1,
) -> IO.NodeOutput:
if image_prompt is None:
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
class FluxKontextProImageNode(IO.ComfyNode):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
"prompt",
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str,
guidance: float,
steps: int,
input_image: torch.Tensor | None = None,
input_image: Input.Image | None = None,
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
NODE_ID = "FluxKontextMaxImageNode"
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProExpandNode(IO.ComfyNode):
"""
Outpaints image based on prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
prompt_upsampling: bool,
top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
mask: torch.Tensor,
image: Input.Image,
mask: Input.Image,
prompt: str,
prompt_upsampling: bool,
steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ProImageNode",
display_name="Flux.2 [pro] Image",
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
default=True,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
"If active, automatically modifies the prompt for more creative generation.",
),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
],
outputs=[IO.Image.Output()],
hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int,
seed: int,
prompt_upsampling: bool,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
) -> IO.NodeOutput:
reference_images = {}
if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode,
FluxProFillNode,
Flux2ProImageNode,
Flux2MaxImageNode,
]

View File

@ -1,13 +1,27 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: Optional[TaskStatusError] = Field(None)
content: Optional[TaskStatusResult] = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
}
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "content") and response.content:
return response.content.video_url
return None
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@ -266,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@ -374,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -388,13 +229,14 @@ class ByteDanceImageEditNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@ -428,13 +270,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4",
display_name="ByteDance Seedream 4.5",
category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.Combo.Input(
"model",
options=["seedream-4-0-250828"],
options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@ -459,7 +301,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -468,7 +310,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -505,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True,
),
@ -532,14 +374,14 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor = None,
image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
sequential_image_generation: str = "disabled",
max_images: int = 1,
seed: int = 0,
watermark: bool = True,
watermark: bool = False,
fail_on_partial: bool = True,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -555,6 +397,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -607,9 +461,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -655,7 +508,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -714,9 +567,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -766,7 +618,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -787,7 +639,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor,
image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -833,9 +685,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -889,7 +740,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -910,8 +761,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -968,9 +819,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -1013,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -1034,7 +884,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
images: torch.Tensor,
images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -1069,8 +919,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
estimated_duration: Optional[int],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@ -1085,7 +935,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:

View File

@ -13,8 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -27,12 +26,15 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -43,6 +45,14 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
class GeminiModel(str, Enum):
@ -68,7 +78,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -132,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -154,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -277,6 +292,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -293,7 +315,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -343,10 +367,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -363,7 +388,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
# Create response
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +401,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
]
],
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -523,6 +552,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -542,10 +578,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,9 +596,13 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -570,11 +611,12 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -640,6 +682,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -662,8 +711,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -679,9 +729,13 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -690,11 +744,12 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
import logging
import math
import re
import torch
from typing_extensions import override
@ -49,12 +50,17 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
OmniProFirstLastFrameRequest,
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
TaskStatusVideoResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -100,10 +106,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -124,8 +126,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -210,7 +210,36 @@ VOICES_CONFIG = {
}
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
def normalize_omni_prompt_references(prompt: str) -> str:
"""
Rewrites Kling Omni-style placeholders used in the app, like:
@image, @image1, @image2, ... @imageN
@video, @video1, @video2, ... @videoN
into the API-compatible form:
<<<image_1>>>, <<<image_2>>>, ...
<<<video_1>>>, <<<video_2>>>, ...
This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
"""
if not prompt:
return prompt
def _image_repl(match):
return f"<<<image_{match.group('idx') or '1'}>>>"
def _video_repl(match):
return f"<<<video_{match.group('idx') or '1'}>>>"
# (?<!\w) avoids matching e.g. "test@image.com"
# (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -218,8 +247,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -450,12 +480,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@ -719,7 +749,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[4],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -777,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -796,17 +827,19 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt: str,
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -829,7 +862,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
@ -842,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True,
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -863,10 +897,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
first_frame: Input.Image,
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
@ -895,12 +935,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -929,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -949,7 +991,9 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
@ -962,13 +1006,14 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1000,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1022,7 +1068,9 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
reference_video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@ -1045,7 +1093,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1053,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1082,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1102,7 +1152,9 @@ class OmniProEditVideoNode(IO.ComfyNode):
video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@ -1125,7 +1177,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1133,11 +1185,96 @@ class OmniProEditVideoNode(IO.ComfyNode):
duration=None,
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
class OmniProImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling Omni Image (Pro)",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
resolution: str,
aspect_ratio: str,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 10:
raise ValueError("The maximum number of reference images is 10.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@ -1207,9 +1344,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image to Video",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1367,7 +1503,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1830,7 +1966,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v1",
default="kling-v2",
),
IO.Combo.Input(
"aspect_ratio",
@ -1913,6 +2049,221 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class MotionControl(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
IO.Video.Input(
"reference_video",
tooltip="Motion reference video used to drive movement/expression.\n"
"Duration limits depend on character_orientation:\n"
" - image: 310s (max 10s)\n"
" - video: 330s (max 30s)",
),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Combo.Input(
"character_orientation",
options=["video", "image"],
tooltip="Controls where the character's facing/orientation comes from.\n"
"video: movements, expressions, camera moves, and orientation "
"follow the motion reference video (other details via prompt).\n"
"image: movements and expressions still follow the motion reference video, "
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
reference_image: Input.Image,
reference_video: Input.Video,
keep_original_sound: bool,
character_orientation: str,
mode: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
if character_orientation == "image":
validate_video_duration(reference_video, min_duration=3, max_duration=10)
else:
validate_video_duration(reference_video, min_duration=3, max_duration=30)
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
response_model=TaskStatusResponse,
data=MotionControlRequest(
prompt=prompt,
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
video_url=await upload_video_to_comfyapi(cls, reference_video),
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1935,6 +2286,10 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
]

View File

@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):

View File

@ -1,11 +1,8 @@
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:

View File

@ -1,46 +1,45 @@
from io import BytesIO
import base64
import os
from enum import Enum
from inspect import cleandoc
from io import BytesIO
import numpy as np
import torch
from PIL import Image
import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
InputImageContent,
InputMessage,
InputMessageContentList,
InputTextContent,
Item,
OpenAICreateResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.apis.openai_api import (
OpenAIImageEditRequest,
OpenAIImageGenerationRequest,
OpenAIImageGenerationResponse,
)
from comfy_api_nodes.util import (
downscale_image_tensor,
download_url_to_bytesio,
validate_string,
tensor_to_base64_string,
ApiEndpoint,
sync_op,
download_url_to_bytesio,
downscale_image_tensor,
poll_op,
sync_op,
tensor_to_base64_string,
text_filepath_to_data_uri,
validate_string,
)
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
class OpenAIDalle2(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
"""
@classmethod
def define_schema(cls):
@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):
class OpenAIDalle3(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
"""
@classmethod
def define_schema(cls):
@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
"""
@classmethod
def define_schema(cls):
@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image 1",
tooltip="Text prompt for GPT Image",
),
IO.Int.Input(
"seed",
@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
),
IO.Combo.Input(
"background",
default="opaque",
options=["opaque", "transparent"],
default="auto",
options=["auto", "opaque", "transparent"],
tooltip="Return image with or without background",
optional=True,
),
@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
optional=True,
),
],
outputs=[
IO.Image.Output(),
@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
async def execute(
cls,
prompt,
seed=0,
quality="low",
background="opaque",
image=None,
mask=None,
n=1,
size="1024x1024",
prompt: str,
seed: int = 0,
quality: str = "low",
background: str = "opaque",
image: Input.Image | None = None,
mask: Input.Image | None = None,
n: int = 1,
size: str = "1024x1024",
model: str = "gpt-image-1",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
files = []
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
else:
raise ValueError(f"Unknown model: {model}")
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type = "multipart/form-data"
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image).squeeze()
single_image = image[i: i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image is None:
raise Exception("Cannot use a mask without an input image")
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
batch, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
response = await sync_op(
cls,
ApiEndpoint(path=path, method="POST"),
response_model=OpenAIImageGenerationResponse,
data=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
content_type=content_type,
)
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=price_extractor,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
price_extractor=price_extractor,
)
return IO.NodeOutput(await validate_and_cast_response(response))

View File

@ -1,568 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

Some files were not shown because too many files have changed in this diff Show More